[Snippets] BrgemmEmitter: blocking by K & N dimensions (#18302)
* KN blocking draft * some fixes * IdentifyBuffers temporary hack * Emitter cleanup * IdentifyBuffers cleanup * BrgemmCopyB validation: N_blk is taken from the child BrgemmCPU * Added blocking parameters to BrgemmCPU * accuracy fixes * Buffers insertion removed from BrgemmToBrgemmCPU * Added blocking parameters to BrgemmCopyB * blocking parameters * Blocking params configuration removed from brgemm_to_brgemm_cpu transformation * Introduced a transformation for blocking parameters configuration * MHA tokenization aligned with blocking matmul requirements * Alexandra's comments applied * Alexandra's comments applied: 2nd round * Ivan's comments applied * MHA tokenization: removed kernel_buffer_size related heuristics
This commit is contained in:
committed by
GitHub
parent
8e671403b3
commit
205de6106b
@@ -4,9 +4,10 @@
|
||||
|
||||
#include "snippets/lowered/pass/identify_buffers.hpp"
|
||||
|
||||
#include "snippets/lowered/linear_ir.hpp"
|
||||
#include "snippets/snippets_isa.hpp"
|
||||
#include "snippets/itt.hpp"
|
||||
#include "snippets/lowered/linear_ir.hpp"
|
||||
#include "snippets/op/brgemm.hpp"
|
||||
#include "snippets/snippets_isa.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace snippets {
|
||||
@@ -20,9 +21,10 @@ inline size_t index(size_t col_num, size_t row, size_t col) {
|
||||
} // namespace
|
||||
|
||||
std::vector<bool> IdentifyBuffers::create_adjacency_matrix(const LinearIR& linear_ir, const BufferSet& buffers) const {
|
||||
// The sync point to check for adjacency is Loop because only in Loop we increment pointers.
|
||||
// So if some Buffers in the one Loop have conflict (cannot be inplace: the different ptr increment and data sizes)
|
||||
// they are called as adjacent
|
||||
// There are several sync points for adjacency check:
|
||||
// 1. Loop because only in Loop we increment pointers. So if some Buffers in the one Loop have conflict
|
||||
// (cannot be inplace: the different ptr increment and data sizes) they are called as adjacent
|
||||
// 2. Brgemm because its blocking implementation requires Buffers with unique memory on all inputs and outputs
|
||||
const auto size = buffers.size();
|
||||
// TODO: Can we use triangular matrix? Need verify using tests
|
||||
std::vector<bool> adj(size * size, false);
|
||||
@@ -49,8 +51,39 @@ std::vector<bool> IdentifyBuffers::create_adjacency_matrix(const LinearIR& linea
|
||||
}
|
||||
};
|
||||
|
||||
auto is_buffer = [](const ExpressionPort& port) {
|
||||
return ov::is_type<op::Buffer>(port.get_expr()->get_node());
|
||||
};
|
||||
|
||||
for (auto expr_it = linear_ir.cbegin(); expr_it != linear_ir.cend(); expr_it++) {
|
||||
const auto &expr = *expr_it;
|
||||
if (const auto brgemm = ov::as_type_ptr<op::Brgemm>(expr->get_node())) {
|
||||
const auto consumers = expr->get_output_port_connector(0)->get_consumers();
|
||||
|
||||
auto buffer_it = std::find_if(consumers.begin(), consumers.end(), is_buffer);
|
||||
if (buffer_it == consumers.end())
|
||||
continue;
|
||||
OPENVINO_ASSERT(std::count_if(consumers.begin(), consumers.end(), is_buffer) == 1, "Brgemm mustn't have more than 1 consumer buffer");
|
||||
|
||||
std::vector<std::shared_ptr<op::Buffer>> adjacency_buffers;
|
||||
adjacency_buffers.push_back(ov::as_type_ptr<op::Buffer>(buffer_it->get_expr()->get_node()));
|
||||
|
||||
for (const auto& input_connector : expr->get_input_port_connectors()) {
|
||||
const auto parent_node = input_connector->get_source().get_expr()->get_node();
|
||||
if (const auto neighbour_buffer = ov::as_type_ptr<op::Buffer>(parent_node)) {
|
||||
adjacency_buffers.push_back(neighbour_buffer);
|
||||
}
|
||||
}
|
||||
for (auto buffer_it = adjacency_buffers.begin(); buffer_it != adjacency_buffers.end(); ++buffer_it) {
|
||||
for (auto neighbour_it = std::next(buffer_it); neighbour_it != adjacency_buffers.end(); ++neighbour_it) {
|
||||
const auto buffer_idx = get_buffer_idx(*buffer_it);
|
||||
const auto neighbour_idx = get_buffer_idx(*neighbour_it);
|
||||
adj[index(size, neighbour_idx, buffer_idx)] = adj[index(size, buffer_idx, neighbour_idx)] = true;
|
||||
}
|
||||
}
|
||||
continue;
|
||||
}
|
||||
|
||||
const auto& loop_end = ov::as_type_ptr<op::LoopEnd>(expr->get_node());
|
||||
if (!loop_end)
|
||||
continue;
|
||||
|
||||
@@ -37,10 +37,11 @@ ov::Shape compute_allocation_shape(const LinearIR::LoopManagerPtr& loop_manager,
|
||||
const std::vector<size_t>& parent_loop_ids,
|
||||
const ov::Output<ov::Node>& parent_output,
|
||||
const int allocation_rank) {
|
||||
const size_t rank = allocation_rank >= 0 ? allocation_rank : parent_output.get_shape().size();
|
||||
ov::Shape allocation_shape(rank);
|
||||
const auto port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(parent_output);
|
||||
const auto planar_shape = utils::get_reordered_planar_shape(ov::Shape{port->get_shape()}, port->get_layout());
|
||||
|
||||
const size_t rank = allocation_rank >= 0 ? std::min(static_cast<size_t>(allocation_rank), planar_shape.size()) : planar_shape.size();
|
||||
ov::Shape allocation_shape(rank);
|
||||
for (size_t i = 0; i < rank; ++i) {
|
||||
*(allocation_shape.rbegin() + i) = (planar_shape.rbegin() + i)->get_length();
|
||||
}
|
||||
|
||||
@@ -87,8 +87,9 @@ void snippets::op::Subgraph::init_config() {
|
||||
auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& ops) -> size_t {
|
||||
// The count of potential unique Buffers - it's hidden virtual ports as well
|
||||
// We should go through Subgraph and calculate potential non-inplace Buffers count.
|
||||
// These Buffers can be only around Loops (for example, around MatMul they may be inplace because MatMul doesn't change registers).
|
||||
// So we should check for element type size of nodes which are used Buffer to get rating from above for unique Buffer count.
|
||||
// These Buffers can be in 2 cases:
|
||||
// 1. Around Loops: we should check for element type size of nodes which use Buffer to get rating from above for unique Buffer count.
|
||||
// 2. Around MatMul: all buffers around Matmul must not be inplace because MatMul blocking implementation changes registers during computations.
|
||||
// The count is estimated because when we calculate this number, we have only original graph representation
|
||||
// and where will be Loops - we can just predict.
|
||||
// Note: The ops that create Buffers: MatMul, Transpose and Softmax (always FP32)
|
||||
@@ -120,16 +121,16 @@ auto snippets::op::Subgraph::get_estimated_buffer_count(const ov::NodeVector& op
|
||||
// They are inplace and the same so we can push precision size only once
|
||||
push_prc_size(ov::element::f32.size());
|
||||
} else if (const auto matmul = ov::as_type_ptr<ov::op::v0::MatMul>(op)) {
|
||||
// First input check is enough because MatMul requires the same prc size on inputs
|
||||
if (!ov::is_type<ov::op::v0::Parameter>(matmul->get_input_node_shared_ptr(0)) ||
|
||||
!ov::is_type<ov::op::v0::Parameter>(matmul->get_input_node_shared_ptr(1))) {
|
||||
push_prc_size(matmul->get_input_element_type(0).size());
|
||||
}
|
||||
// Since all buffers around Matmul must be unique, we explicitly add values to the vector without any checks
|
||||
if (!ov::is_type<ov::op::v0::Parameter>(matmul->get_input_node_shared_ptr(0)))
|
||||
used_precision_size.push_back(matmul->get_input_element_type(0).size());
|
||||
if (!ov::is_type<ov::op::v0::Parameter>(matmul->get_input_node_shared_ptr(1)))
|
||||
used_precision_size.push_back(matmul->get_input_element_type(1).size());
|
||||
|
||||
const auto consumers = matmul->get_output_target_inputs(0);
|
||||
if (std::none_of(consumers.begin(), consumers.end(),
|
||||
[](const ov::Input<ov::Node>& in) { return ov::is_type<ov::op::v0::Result>(in.get_node()); })) {
|
||||
push_prc_size(matmul->get_element_type().size());
|
||||
used_precision_size.push_back(matmul->get_element_type().size());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -208,10 +208,13 @@ ov::snippets::pass::TokenizeMHASnippets::TokenizeMHASnippets(const SnippetsToken
|
||||
// - Between MatMul0 and Transpose1 - At the moment operations after Transpose1 cannot be fused in Transpose Loop (to avoid performance regressions).
|
||||
// But operations after Transpose1 and before MatMul0 will be fused into one loop as well (look at first point)
|
||||
// Note: If the pass is updated, need to check the new possible branches for potential non-inplace Buffers!
|
||||
// Default value is 1 because
|
||||
// - Firstly Softmax always need to have Buffers
|
||||
// - Secondly Softmax need 2 Buffer but they can be inplace - One virtual port is enough for Softmax
|
||||
size_t buffer_count = 1;
|
||||
// Default value is 2 because
|
||||
// - Firstly, Softmax always needs Buffers
|
||||
// - Secondly, Softmax needs 2 Buffers but they can be inplace - One virtual port is enough for Softmax => buffer_count = 1
|
||||
// - Thirdly, MatMul requires unique Buffers on inputs and outputs because blocking implementation increments input/output pointers during computations
|
||||
// However, all of the Buffers are usually reused by the next MatMul and Softmax.
|
||||
// So on sufficiently large subgraphs we use only one additional unique buffer => buffer_count increments by 1
|
||||
size_t buffer_count = 2;
|
||||
std::string fused_names;
|
||||
ov::NodeVector ordered_ops;
|
||||
|
||||
|
||||
@@ -701,11 +701,13 @@ void StoreConvertEmitter::emit_isa(const std::vector<size_t> &in, const std::vec
|
||||
// Delegates to the emit_data() of the store_emitter member; this emitter adds nothing of its own here.
void StoreConvertEmitter::emit_data() const {
|
||||
store_emitter->emit_data();
|
||||
}
|
||||
size_t BrgemmEmitter::getBrgIdx(size_t kIdx, size_t nIdx) const {
|
||||
return kIdx * 2 + nIdx;
|
||||
size_t BrgemmEmitter::getBrgIdx(size_t kIdx, size_t nIdx) {
|
||||
return kIdx * BRGEMM_N_KERNEL_NUM + nIdx;
|
||||
}
|
||||
BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa,
|
||||
const std::shared_ptr<ov::Node>& node) : jit_emitter(h, isa, node) {
|
||||
m_brgCtxs.fill(brgemmCtx());
|
||||
std::generate(m_brgKernels.begin(), m_brgKernels.end(), [](){ return nullptr; });
|
||||
in_out_type_ = emitter_in_out_map::gpr_to_gpr;
|
||||
const auto& brgemm_node = as_type_ptr<ov::intel_cpu::BrgemmCPU>(node);
|
||||
if (brgemm_node->is_dynamic())
|
||||
@@ -764,44 +766,69 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl:
|
||||
|
||||
auto brg0Prc = InferenceEngine::details::convertPrecision(brgemm_node->get_input_element_type(0));
|
||||
auto brg1Prc = InferenceEngine::details::convertPrecision(brgemm_node->get_input_element_type(1));
|
||||
io_data_size = {brg0Prc.size(), brg1Prc.size(), brgemm_node->get_output_element_type(0).size()};
|
||||
m_brg0VnniFactor = 4 / brg0Prc.size();
|
||||
bool brgWithAMX = brgemm_node->is_amx();
|
||||
|
||||
io_data_size = {brg0Prc.size(), brg1Prc.size()};
|
||||
if (brgemm_node->get_input_size() == 3)
|
||||
io_data_size.push_back(brgemm_node->get_input_element_type(2).size());
|
||||
io_data_size.push_back(brgemm_node->get_output_element_type(0).size());
|
||||
|
||||
m_with_comp = brgemm_node->is_with_compensations();
|
||||
m_with_scratch = brgemm_node->is_with_scratchpad();
|
||||
|
||||
m_N_blk = brg1Prc == Precision::FP32 ? m_N :
|
||||
brg1Prc == Precision::BF16 ? 32 : 64;
|
||||
m_N_blk = brgemm_node->get_n_block_size();
|
||||
m_K_blk = brgemm_node->get_k_block_size();
|
||||
m_N_tail = m_N % m_N_blk;
|
||||
m_K_blk = brgWithAMX ? brg0Prc == Precision::BF16 ? 32 : 64
|
||||
: m_K;
|
||||
m_K_tail = m_K % m_K_blk;
|
||||
|
||||
for (size_t k = 0; k < 2; k++) {
|
||||
for (size_t n = 0; n < 2; n++) {
|
||||
auto& brgemmCtx = m_brgCtxs0[getBrgIdx(k, n)];
|
||||
m_N_blk_loop = m_N >= 2 * m_N_blk;
|
||||
m_K_blk_loop = m_K >= 3 * m_K_blk;
|
||||
OPENVINO_ASSERT((!brgemm_node->is_with_data_repacking()) || (!m_N_blk_loop && !m_K_blk_loop),
|
||||
"BrgemmEmitter doesn't support blocking by K, N dimensions when data repacking is needed!");
|
||||
|
||||
auto M_ = m_M;
|
||||
auto N_ = n ? m_N_tail : m_N - m_N_tail;
|
||||
auto K_ = k ? m_K_tail : m_K - m_K_tail;
|
||||
auto beta = k && m_brgCtxs0[getBrgIdx(0, n)].K != 0 ? 1.0f : 0.0f;
|
||||
auto N = [&](size_t n) {
|
||||
switch (n) {
|
||||
case 0: return m_N_blk;
|
||||
case 1: return m_N_tail;
|
||||
default: OPENVINO_THROW("BrgemmEmitter detected unsupported N value");
|
||||
}
|
||||
};
|
||||
auto K = [&](size_t k) {
|
||||
switch (k) {
|
||||
case 0: return m_K_blk;
|
||||
case 1: return m_K >= 2 * m_K_blk ? m_K_blk : 0;
|
||||
case 2: return m_K_tail;
|
||||
default: IE_THROW() << "BrgemmEmitter detected unsupported K value";
|
||||
}
|
||||
};
|
||||
|
||||
brgemmCtx.M = M_;
|
||||
brgemmCtx.N = N_;
|
||||
brgemmCtx.K = K_;
|
||||
bool has_K_kernel = false;
|
||||
for (size_t k = 0; k < BRGEMM_K_KERNEL_NUM; k++) {
|
||||
bool has_N_kernel = false;
|
||||
for (size_t n = 0; n < BRGEMM_N_KERNEL_NUM; n++) {
|
||||
const size_t kernel_idx = getBrgIdx(k, n);
|
||||
auto& brgemmCtx = m_brgCtxs[kernel_idx];
|
||||
|
||||
brgemmCtx.M = m_M;
|
||||
brgemmCtx.N = N(n);
|
||||
brgemmCtx.K = K(k);
|
||||
brgemmCtx.LDA = leading_dimensions[0];
|
||||
brgemmCtx.LDB = brgemm_node->is_with_data_repacking() ? rnd_up(m_N, m_N_blk) : leading_dimensions[1];
|
||||
brgemmCtx.LDB = brgemm_node->is_with_data_repacking() ? rnd_up(m_N, brgemm_copy->get_n_block_size()) : leading_dimensions[1];
|
||||
brgemmCtx.LDC = leading_dimensions[2];
|
||||
brgemmCtx.dt_in0 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::IEPrecisionToDataType(brg0Prc));
|
||||
brgemmCtx.dt_in1 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::IEPrecisionToDataType(brg1Prc));
|
||||
brgemmCtx.beta = beta;
|
||||
brgemmCtx.beta = has_K_kernel ? 1 : 0;
|
||||
|
||||
// don't create brgemm kernels for empty tiles
|
||||
if (M_ != 0 && K_ != 0 && N_ != 0) {
|
||||
initBrgemm(brgemmCtx, m_brgKernels0[getBrgIdx(k, n)], brgWithAMX);
|
||||
}
|
||||
if (brgemmCtx.N == 0 || brgemmCtx.N > m_N ||
|
||||
brgemmCtx.K == 0 || brgemmCtx.K > m_K)
|
||||
continue;
|
||||
|
||||
initBrgemm(brgemmCtx, m_brgKernels[kernel_idx], brgWithAMX);
|
||||
has_N_kernel = true;
|
||||
}
|
||||
if (has_N_kernel)
|
||||
has_K_kernel = true;
|
||||
}
|
||||
|
||||
m_load_offset_a = brgemm_node->get_offset_a();
|
||||
@@ -831,14 +858,39 @@ std::set<std::vector<element::Type>> BrgemmEmitter::get_supported_precisions(con
|
||||
}
|
||||
}
|
||||
|
||||
void BrgemmEmitter::initBrgemm(brgemmCtx& ctx, std::unique_ptr<brgemm_kernel_t>& brgKernel, bool use_amx) const {
|
||||
void BrgemmEmitter::validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
|
||||
std::set<size_t> unique_ids{in[0], in[1], out[0]};
|
||||
size_t unique_ids_count = 3;
|
||||
auto add_reg_to_unique_ids = [&](const size_t reg_number) {
|
||||
unique_ids.insert(reg_number);
|
||||
unique_ids_count++;
|
||||
};
|
||||
|
||||
if (m_N_blk_loop || m_K_blk_loop) {
|
||||
if (aux_gpr_idxs.size() < static_cast<size_t>(m_N_blk_loop) + static_cast<size_t>(m_K_blk_loop))
|
||||
IE_THROW() << "BRGEMM Emitter requires extra gpr which was not allocated";
|
||||
if (m_N_blk_loop)
|
||||
add_reg_to_unique_ids(aux_gpr_idxs[0]);
|
||||
if (m_K_blk_loop)
|
||||
add_reg_to_unique_ids(aux_gpr_idxs[m_N_blk_loop]);
|
||||
}
|
||||
if (m_with_scratch) {
|
||||
if (in.size() != 3)
|
||||
IE_THROW() << "BRGEMM Emitter expects 3 inputs if there are compensations/wsp";
|
||||
add_reg_to_unique_ids(in[2]);
|
||||
}
|
||||
if (unique_ids.size() != unique_ids_count) {
|
||||
IE_THROW() << "BRGEMM Emitter expects that all input/output registers are unique";
|
||||
}
|
||||
}
|
||||
|
||||
void BrgemmEmitter::initBrgemm(brgemmCtx& ctx, std::unique_ptr<brgemm_kernel_t>& brgKernel, bool use_amx) {
|
||||
brgemm_t brgDesc;
|
||||
brgemm_strides_t strides {static_cast<dnnl_dim_t>(ctx.M * ctx.K), static_cast<dnnl_dim_t>(ctx.K * ctx.N)};
|
||||
const bool is_int8 = utils::one_of(ctx.dt_in0, data_type::u8, data_type::s8) && utils::one_of(ctx.dt_in1, data_type::u8, data_type::s8);
|
||||
auto isa = use_amx ? isa_undef
|
||||
: ctx.dt_in0 == dnnl_data_type_t::dnnl_bf16 ? avx512_core_bf16 : (is_int8 ? avx512_core_vnni : avx512_core);
|
||||
auto status = brgemm_desc_init(&brgDesc, isa, brgemm_strd, ctx.dt_in0, ctx.dt_in1,
|
||||
false, false, brgemm_row_major, 1.f, ctx.beta, ctx.LDA, ctx.LDB, ctx.LDC, ctx.M, ctx.N, ctx.K, &strides);
|
||||
false, false, brgemm_row_major, 1.f, ctx.beta, ctx.LDA, ctx.LDB, ctx.LDC, ctx.M, ctx.N, ctx.K, nullptr);
|
||||
if (status != dnnl_success)
|
||||
IE_THROW() << "BrgemmEmitter cannot initialize brgemm descriptor due to invalid params";
|
||||
|
||||
@@ -856,48 +908,118 @@ void BrgemmEmitter::initBrgemm(brgemmCtx& ctx, std::unique_ptr<brgemm_kernel_t>&
|
||||
brgKernel.reset(brgKernel_);
|
||||
}
|
||||
|
||||
size_t BrgemmEmitter::aux_gprs_count() const {
|
||||
return m_N_blk_loop + m_K_blk_loop;
|
||||
}
|
||||
|
||||
void BrgemmEmitter::emit_N_blocking_loops(size_t k_kernel_id,
|
||||
const Xbyak::Reg64& input_0, const Xbyak::Reg64& input_1,
|
||||
const Xbyak::Reg64& input_2, const Xbyak::Reg64& output_0,
|
||||
const Xbyak::Reg64& work_amount_N) const {
|
||||
// Blocked N loop
|
||||
size_t kernel_idx = getBrgIdx(k_kernel_id, 0);
|
||||
if (m_brgKernels[kernel_idx]) {
|
||||
const auto& brgemmCtx = m_brgCtxs[kernel_idx];
|
||||
Label N_loop_begin;
|
||||
if (m_N_blk_loop) {
|
||||
h->mov(work_amount_N, m_N);
|
||||
h->L(N_loop_begin);
|
||||
}
|
||||
|
||||
emit_brgemm_kernel_call(m_brgKernels[kernel_idx].get(), brgemmCtx, input_0, input_1, input_2, output_0);
|
||||
// We don't need to increment pointers if we cover full N dimension in one kernel call
|
||||
if (m_N_blk_loop || m_N_tail != 0) {
|
||||
h->add(output_0, brgemmCtx.N * io_data_size.back());
|
||||
h->add(input_1, brgemmCtx.N * io_data_size[1]);
|
||||
if (m_with_scratch && m_with_comp)
|
||||
h->add(input_2, brgemmCtx.N * io_data_size[2]);
|
||||
}
|
||||
|
||||
if (m_N_blk_loop) {
|
||||
h->sub(work_amount_N, brgemmCtx.N);
|
||||
h->cmp(work_amount_N, brgemmCtx.N);
|
||||
h->jge(N_loop_begin);
|
||||
}
|
||||
}
|
||||
// N loop tail
|
||||
kernel_idx = getBrgIdx(k_kernel_id, 1);
|
||||
if (m_brgKernels[kernel_idx])
|
||||
emit_brgemm_kernel_call(m_brgKernels[kernel_idx].get(), m_brgCtxs[kernel_idx], input_0, input_1, input_2, output_0);
|
||||
|
||||
if (m_N_blk_loop || m_N_tail != 0) {
|
||||
h->sub(input_1, (m_N - m_N_tail) * io_data_size[1]);
|
||||
h->sub(output_0, (m_N - m_N_tail) * io_data_size.back());
|
||||
if (m_with_scratch && m_with_comp)
|
||||
h->sub(input_2, (m_N - m_N_tail) * io_data_size[2]);
|
||||
}
|
||||
}
|
||||
|
||||
void BrgemmEmitter::emit_impl(const std::vector<size_t>& in,
|
||||
const std::vector<size_t>& out) const {
|
||||
validate_arguments(in, out);
|
||||
if (host_isa_ == cpu::x64::avx512_core) {
|
||||
Xbyak::Reg64 input_0(static_cast<int>(in[0]));
|
||||
Xbyak::Reg64 input_1(static_cast<int>(in[1]));
|
||||
Xbyak::Reg64 input_2(static_cast<int>(0)); // scratch. Default reg index is 0 if there isn't scratch
|
||||
if (m_with_scratch) {
|
||||
if (in.size() != 3) {
|
||||
IE_THROW() << "BRGEMM Emitter expects 3 inputs if there are compensations/wsp";
|
||||
}
|
||||
input_2 = Xbyak::Reg64(static_cast<int>(in[2]));
|
||||
}
|
||||
Xbyak::Reg64 output_0(static_cast<int>(out[0]));
|
||||
Xbyak::Reg64 work_amount_N(m_N_blk_loop ? static_cast<int>(aux_gpr_idxs[0]) : 0);
|
||||
Xbyak::Reg64 work_amount_K(m_K_blk_loop ? static_cast<int>(aux_gpr_idxs[m_N_blk_loop]) : 0);
|
||||
h->add(input_0, m_load_offset_a);
|
||||
h->add(input_1, m_load_offset_b);
|
||||
h->add(output_0, m_store_offset_c);
|
||||
if (m_with_scratch) {
|
||||
input_2 = Xbyak::Reg64(static_cast<int>(in[2]));
|
||||
h->add(input_2, m_load_offset_scratch);
|
||||
}
|
||||
|
||||
size_t brgIdx0 = getBrgIdx(0, 0);
|
||||
size_t K0_step0 = m_brgCtxs0[brgIdx0].K;
|
||||
size_t K0_step1 = m_brgCtxs0[brgIdx0].K * m_brgCtxs0[brgIdx0].LDB;
|
||||
size_t N0_step0 = m_brgCtxs0[brgIdx0].N * m_brg0VnniFactor;
|
||||
size_t N0_step1 = m_brgCtxs0[brgIdx0].N;
|
||||
for (size_t n = 0; n < 2; n++) {
|
||||
for (size_t k = 0; k < 2; k++) {
|
||||
auto& brgemmCtx = m_brgCtxs0[getBrgIdx(k, n)];
|
||||
// fills kernel_idx with the first idx of non-empty K kernel or returns false
|
||||
auto get_K_kernel_idx = [&](size_t k_kernel_id, size_t& kernel_idx) {
|
||||
for (size_t n = 0; n < BRGEMM_N_KERNEL_NUM; n++) {
|
||||
const auto idx = getBrgIdx(k_kernel_id, n);
|
||||
if (m_brgKernels[idx]) {
|
||||
kernel_idx = idx;
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
// Blocked K loop
|
||||
const auto k_tail_id = BRGEMM_K_KERNEL_NUM - 1;
|
||||
size_t total_K_work_amount = m_K;
|
||||
size_t kernel_idx = SIZE_MAX;
|
||||
for (size_t k_blocked_id = 0; k_blocked_id < k_tail_id; k_blocked_id++) {
|
||||
if (get_K_kernel_idx(k_blocked_id, kernel_idx)) {
|
||||
const auto& brgemmCtx = m_brgCtxs[kernel_idx];
|
||||
Label K_loop_begin;
|
||||
// Note: we never emit a loop for the first blocked kernel, since it is always executed only once.
|
||||
// The purpose of the first blocked K kernel is to initialize the output, because it has beta = 0
|
||||
if (k_blocked_id == 0) {
|
||||
total_K_work_amount -= brgemmCtx.K;
|
||||
} else if (m_K_blk_loop) {
|
||||
h->mov(work_amount_K, total_K_work_amount);
|
||||
h->L(K_loop_begin);
|
||||
}
|
||||
|
||||
if (brgemmCtx.K != 0 && brgemmCtx.N != 0) {
|
||||
const size_t in0_offset = m_load_offset_a + k * K0_step0 * io_data_size[0];
|
||||
const size_t in1_offset = m_load_offset_b + (k * K0_step1 + n * N0_step0) * io_data_size[1];
|
||||
const size_t in2_offset = m_load_offset_scratch + (m_with_comp ? n * N0_step1 * sizeof(int32_t) : 0);
|
||||
const size_t out0_offset = m_store_offset_c + n * N0_step1 * io_data_size[2];
|
||||
|
||||
emit_brgemm_kernel_call(m_brgKernels0[getBrgIdx(k, n)].get(),
|
||||
brgemmCtx,
|
||||
input_0,
|
||||
input_1,
|
||||
input_2,
|
||||
output_0,
|
||||
in0_offset,
|
||||
in1_offset,
|
||||
in2_offset,
|
||||
out0_offset);
|
||||
emit_N_blocking_loops(k_blocked_id, input_0, input_1, input_2, output_0, work_amount_N);
|
||||
h->add(input_0, brgemmCtx.K * io_data_size[0]);
|
||||
h->add(input_1, (brgemmCtx.K * brgemmCtx.LDB) * io_data_size[1]);
|
||||
if (m_K_blk_loop && k_blocked_id) {
|
||||
h->sub(work_amount_K, brgemmCtx.K);
|
||||
h->cmp(work_amount_K, brgemmCtx.K);
|
||||
h->jge(K_loop_begin);
|
||||
}
|
||||
}
|
||||
}
|
||||
// K loop tail
|
||||
if (get_K_kernel_idx(k_tail_id, kernel_idx)) {
|
||||
emit_N_blocking_loops(k_tail_id, input_0, input_1, input_2, output_0, work_amount_N);
|
||||
}
|
||||
|
||||
h->sub(input_0, m_load_offset_a + (m_K - m_K_tail) * io_data_size[0]);
|
||||
h->sub(input_1, m_load_offset_b + (m_K - m_K_tail) * m_brgCtxs[0].LDB * io_data_size[1]);
|
||||
if (m_with_scratch)
|
||||
h->sub(input_2, m_load_offset_scratch);
|
||||
h->sub(output_0, m_store_offset_c);
|
||||
} else {
|
||||
IE_THROW() << "BrgemmEmitter requires at least avx512_core instruction set";
|
||||
}
|
||||
@@ -1106,19 +1228,18 @@ BrgemmCopyBEmitter::BrgemmCopyBEmitter(dnnl::impl::cpu::x64::jit_generator* h, d
|
||||
m_N = *(transposed_shape.rbegin());
|
||||
m_K = *(transposed_shape.rbegin() + 1);
|
||||
|
||||
const bool isAMXSupported = mayiuse(avx512_core_amx);
|
||||
const auto use_amx = isAMXSupported && m_brgemm_prc_in0 != ov::element::f32 && (m_K % m_brgemmVNNIFactor == 0) && (m_N % m_brgemmVNNIFactor == 0);
|
||||
m_N_blk = brgemm_repack->get_n_block_size();
|
||||
m_K_blk = brgemm_repack->get_k_block_size();
|
||||
|
||||
m_N_blk = m_brgemm_prc_in1 == ov::element::f32 ? m_N :
|
||||
m_brgemm_prc_in1 == ov::element::bf16 ? 32 : 64;
|
||||
m_K_blk = use_amx ? m_brgemm_prc_in0 == ov::element::bf16 ? 32 : 64
|
||||
: m_K;
|
||||
m_N_tail = m_N % m_N_blk;
|
||||
m_K_tail = m_K % m_K_blk;
|
||||
m_LDB = m_brgemm_prc_in1 == ov::element::f32 ? leading_dimension : rnd_up(m_N, m_N_blk);
|
||||
|
||||
const auto dt_in0 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::IEPrecisionToDataType(InferenceEngine::details::convertPrecision(m_brgemm_prc_in0)));
|
||||
const auto dt_in1 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::IEPrecisionToDataType(InferenceEngine::details::convertPrecision(m_brgemm_prc_in1)));
|
||||
|
||||
const bool isAMXSupported = mayiuse(avx512_core_amx);
|
||||
const auto use_amx = isAMXSupported && m_brgemm_prc_in0 != ov::element::f32 && (m_K % m_brgemmVNNIFactor == 0) && (m_N % m_brgemmVNNIFactor == 0);
|
||||
init_brgemm_copy(m_kernel, leading_dimension, m_N_blk, m_N_tail, m_LDB, m_K - m_K_tail, use_amx, dt_in0, dt_in1);
|
||||
}
|
||||
|
||||
|
||||
@@ -338,37 +338,49 @@ public:
|
||||
|
||||
size_t get_inputs_num() const override { return m_with_scratch ? 3 : 2; }
|
||||
static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ngraph::Node>& node = nullptr);
|
||||
size_t aux_gprs_count() const override;
|
||||
|
||||
private:
|
||||
void validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const override;
|
||||
void emit_impl(const std::vector<size_t>& in,
|
||||
const std::vector<size_t>& out) const override;
|
||||
|
||||
std::vector<size_t> io_data_size {};
|
||||
struct brgemmCtx {
|
||||
brgemmCtx() : M(0), N(0), K(0),
|
||||
LDA(0), LDB(0), LDC(0),
|
||||
dt_in0(dnnl_f32), dt_in1(dnnl_f32),
|
||||
is_with_amx(false), is_with_comp(false), beta(0) {}
|
||||
size_t M, N, K, LDA, LDB, LDC;
|
||||
dnnl_data_type_t dt_in0, dt_in1;
|
||||
char palette[64];
|
||||
char palette[64] = {};
|
||||
bool is_with_amx;
|
||||
bool is_with_comp;
|
||||
float beta;
|
||||
};
|
||||
void initBrgemm(brgemmCtx& ctx, std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& brgKernel, bool use_amx) const;
|
||||
size_t getBrgIdx(size_t kIdx, size_t nIdx) const;
|
||||
static void initBrgemm(brgemmCtx& ctx, std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& brgKernel, bool use_amx);
|
||||
static size_t getBrgIdx(size_t kIdx, size_t nIdx);
|
||||
|
||||
void emit_brgemm_kernel_call(const dnnl::impl::cpu::x64::brgemm_kernel_t* brg_kernel, const brgemmCtx& ctx,
|
||||
Xbyak::Reg64 addr_A, Xbyak::Reg64 addr_B, Xbyak::Reg64 scratch, Xbyak::Reg64 addr_C,
|
||||
const size_t in0_kernel_offset, const size_t in1_kernel_offset,
|
||||
const size_t in2_kernel_offset, const size_t out0_kernel_offset) const;
|
||||
size_t in0_kernel_offset = 0, size_t in1_kernel_offset = 0,
|
||||
size_t in2_kernel_offset = 0, size_t out0_kernel_offset = 0) const;
|
||||
static void kernel_execute(const dnnl::impl::cpu::x64::brgemm_kernel_t *brg_kernel, const void *A, const void *B, void *C, void *scratch, int with_comp);
|
||||
void emit_N_blocking_loops(size_t k_kernel_id,
|
||||
const Xbyak::Reg64& input_0, const Xbyak::Reg64& input_1,
|
||||
const Xbyak::Reg64& input_2, const Xbyak::Reg64& output_0,
|
||||
const Xbyak::Reg64& work_amount_N) const;
|
||||
|
||||
static constexpr size_t BRGEMM_KERNELS_NUM = 8;
|
||||
brgemmCtx m_brgCtxs0[BRGEMM_KERNELS_NUM];
|
||||
std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t> m_brgKernels0[BRGEMM_KERNELS_NUM];
|
||||
// Note: K dimension is covered by TWO blocked kernels (with beta = 0 and 1) + 1 for tail
|
||||
static constexpr size_t BRGEMM_K_KERNEL_NUM = 3;
|
||||
static constexpr size_t BRGEMM_N_KERNEL_NUM = 2;
|
||||
std::array<brgemmCtx, BRGEMM_K_KERNEL_NUM * BRGEMM_N_KERNEL_NUM> m_brgCtxs;
|
||||
std::array<std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>, BRGEMM_K_KERNEL_NUM * BRGEMM_N_KERNEL_NUM> m_brgKernels;
|
||||
|
||||
size_t m_M;
|
||||
size_t m_K, m_K_blk, m_K_tail;
|
||||
size_t m_N, m_N_blk, m_N_tail;
|
||||
size_t m_brg0VnniFactor;
|
||||
bool m_N_blk_loop = false;
|
||||
bool m_K_blk_loop = false;
|
||||
|
||||
bool m_with_scratch = false;
|
||||
bool m_with_comp = false;
|
||||
@@ -377,6 +389,8 @@ private:
|
||||
size_t m_load_offset_b = 0lu;
|
||||
size_t m_load_offset_scratch = 0lu;
|
||||
size_t m_store_offset_c = 0lu;
|
||||
|
||||
std::vector<size_t> io_data_size {};
|
||||
};
|
||||
|
||||
class BrgemmCopyBEmitter : public jit_emitter {
|
||||
|
||||
@@ -30,6 +30,7 @@
|
||||
#include "transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp"
|
||||
#include "transformations/snippets/x64/pass/remove_converts.hpp"
|
||||
#include "transformations/snippets/x64/pass/enforce_precision.hpp"
|
||||
#include "transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.hpp"
|
||||
#include "transformations/cpu_opset/common/pass/convert_to_swish_cpu.hpp"
|
||||
#include "transformations/defs.hpp"
|
||||
|
||||
@@ -560,6 +561,7 @@ void Snippet::generate(const jit_snippets_compile_args* jcp) {
|
||||
|
||||
ov::pass::Manager post_dialect;
|
||||
CPU_REGISTER_PASS_X64(post_dialect, ov::intel_cpu::pass::BrgemmToBrgemmCPU);
|
||||
CPU_REGISTER_PASS_X64(post_dialect, ov::intel_cpu::pass::SetBrgemmCPUBlockingParams);
|
||||
|
||||
ov::pass::Manager post_precision;
|
||||
CPU_REGISTER_PASS_X64(post_precision, ov::intel_cpu::pass::RemoveConverts);
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
|
||||
#include "snippets/itt.hpp"
|
||||
#include "snippets/utils.hpp"
|
||||
#include "snippets/op/buffer.hpp"
|
||||
|
||||
#include "brgemm_copy_b.hpp"
|
||||
|
||||
@@ -12,27 +13,32 @@
|
||||
using namespace ov;
|
||||
|
||||
intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output<Node>& x, const element::Type src_type, const Type type,
|
||||
const size_t offset_in, const size_t offset_out0, const size_t offset_out1, std::vector<size_t> layout_input)
|
||||
: snippets::op::MemoryAccess({x}, 1, type == Type::WithCompensations ? 2 : 1), m_type(type), m_src_type(src_type) {
|
||||
const size_t offset_in, const size_t offset_out0, const size_t offset_out1,
|
||||
std::vector<size_t> layout_input, const size_t blk_size_k, const size_t blk_size_n)
|
||||
: snippets::op::MemoryAccess({x}, 1, type == Type::WithCompensations ? 2 : 1),
|
||||
m_type(type), m_src_type(src_type) {
|
||||
set_output_size(type == Type::WithCompensations ? 2 : 1);
|
||||
set_input_port_descriptor({0, offset_in}, 0);
|
||||
set_output_port_descriptor({0, offset_out0}, 0);
|
||||
if (is_with_compensations()) {
|
||||
set_output_port_descriptor({0, offset_out1}, 1);
|
||||
}
|
||||
compute_block_size_values(blk_size_k, blk_size_n);
|
||||
custom_constructor_validate_and_infer_types(std::move(layout_input));
|
||||
}
|
||||
|
||||
intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output<Node>& x, const element::Type src_type, const Type type,
|
||||
const PortDescriptor& desc_in0, const PortDescriptor& desc_out0, const PortDescriptor& desc_out1,
|
||||
std::vector<size_t> layout_input)
|
||||
: snippets::op::MemoryAccess({x}, 1, type == Type::WithCompensations ? 2 : 1), m_type(type), m_src_type(src_type) {
|
||||
std::vector<size_t> layout_input, const size_t blk_size_k, const size_t blk_size_n)
|
||||
: snippets::op::MemoryAccess({x}, 1, type == Type::WithCompensations ? 2 : 1),
|
||||
m_type(type), m_src_type(src_type) {
|
||||
set_output_size(type == Type::WithCompensations ? 2 : 1);
|
||||
set_input_port_descriptor(desc_in0, 0);
|
||||
set_output_port_descriptor(desc_out0, 0);
|
||||
if (is_with_compensations()) {
|
||||
set_output_port_descriptor(desc_out1, 1);
|
||||
}
|
||||
compute_block_size_values(blk_size_k, blk_size_n);
|
||||
custom_constructor_validate_and_infer_types(std::move(layout_input));
|
||||
}
|
||||
|
||||
@@ -75,16 +81,21 @@ void intel_cpu::BrgemmCopyB::validate(const ov::PartialShape& pshape, const ov::
|
||||
const auto shape = pshape.get_shape();
|
||||
const auto N = *shape.rbegin();
|
||||
const auto K = *(shape.rbegin() + 1);
|
||||
const auto N_blk = element_type == element::bf16 ? 32 : 64;
|
||||
const auto brgemmVNNIFactor = 4 / m_src_type.size();
|
||||
|
||||
set_output_type(0, element_type, ov::PartialShape{ov::Dimension(rnd_up(K, brgemmVNNIFactor)),
|
||||
ov::Dimension(rnd_up(N, N_blk))});
|
||||
ov::Dimension(rnd_up(N, m_N_blk))});
|
||||
if (is_with_compensations()) {
|
||||
set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension(rnd_up(N, N_blk))});
|
||||
set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension(rnd_up(N, m_N_blk))});
|
||||
}
|
||||
}
|
||||
|
||||
void intel_cpu::BrgemmCopyB::compute_block_size_values(const size_t blk_size_k, const size_t blk_size_n) {
|
||||
const auto input_shape = snippets::utils::get_port_planar_shape(input(0)).get_shape();
|
||||
m_K_blk = blk_size_k != 0 ? blk_size_k : *(input_shape.rbegin() + 1);
|
||||
m_N_blk = blk_size_n != 0 ? blk_size_n : *input_shape.rbegin();
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> intel_cpu::BrgemmCopyB::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
INTERNAL_OP_SCOPE(BrgemmRepack_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
@@ -92,7 +103,8 @@ std::shared_ptr<Node> intel_cpu::BrgemmCopyB::clone_with_new_inputs(const Output
|
||||
get_input_port_descriptor(0),
|
||||
get_output_port_descriptor(0),
|
||||
is_with_compensations() ? get_output_port_descriptor(1) : PortDescriptor{},
|
||||
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout());
|
||||
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout(),
|
||||
m_K_blk, m_N_blk);
|
||||
}
|
||||
|
||||
size_t intel_cpu::BrgemmCopyB::get_offset_compensations() const {
|
||||
|
||||
@@ -27,16 +27,21 @@ public:
|
||||
|
||||
BrgemmCopyB(const Output<Node>& x, const element::Type src_type, const Type type = Type::OnlyRepacking,
|
||||
const size_t offset_in = 0lu, const size_t offset_out0 = 0lu, const size_t offset_out1 = 0lu,
|
||||
std::vector<size_t> layout_input = {});
|
||||
std::vector<size_t> layout_input = {}, const size_t blk_size_k = 0, const size_t blk_size_n = 0);
|
||||
BrgemmCopyB(const Output<Node>& x, const element::Type src_type, const Type type = Type::OnlyRepacking,
|
||||
const PortDescriptor& desc_in0 = {}, const PortDescriptor& desc_out0 = {}, const PortDescriptor& desc_out1 = {},
|
||||
std::vector<size_t> layout_input = {});
|
||||
std::vector<size_t> layout_input = {}, const size_t blk_size_k = 0, const size_t blk_size_n = 0);
|
||||
BrgemmCopyB() = default;
|
||||
|
||||
size_t get_offset_in() const { return get_input_offset(0); }
|
||||
size_t get_offset_out() const { return get_output_offset(0); }
|
||||
size_t get_offset_compensations() const;
|
||||
|
||||
size_t get_k_block_size() const { return m_K_blk; }
|
||||
size_t get_n_block_size() const { return m_N_blk; }
|
||||
void set_k_block_size(size_t block_size) { m_K_blk = block_size; }
|
||||
void set_n_block_size(size_t block_size) { m_N_blk = block_size; }
|
||||
|
||||
Type get_type() const { return m_type; }
|
||||
element::Type get_src_element_type() const { return m_src_type; }
|
||||
bool is_with_compensations() const { return m_type == Type::WithCompensations; }
|
||||
@@ -49,9 +54,13 @@ public:
|
||||
private:
|
||||
void custom_constructor_validate_and_infer_types(std::vector<size_t> layout_input = {});
|
||||
void validate(const ov::PartialShape& pshape, const ov::element::Type& element_type);
|
||||
void compute_block_size_values(const size_t blk_size_k, const size_t blk_size_n);
|
||||
|
||||
Type m_type = Type::OnlyRepacking;
|
||||
element::Type m_src_type = ov::element::undefined; // src element type of the corresponding BRGEMM
|
||||
|
||||
size_t m_K_blk = 0;
|
||||
size_t m_N_blk = 0;
|
||||
};
|
||||
|
||||
} // namespace intel_cpu
|
||||
|
||||
@@ -14,7 +14,8 @@ namespace intel_cpu {
|
||||
|
||||
BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Type type,
|
||||
const size_t offset_a, const size_t offset_b, const size_t offset_c,
|
||||
std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c)
|
||||
std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c,
|
||||
const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n)
|
||||
: Brgemm(), m_type(type) {
|
||||
// We call default ctor of Brgemm class to avoid incorrect shape infer in constructor_validate_and_type_infer() call
|
||||
set_arguments({A, B});
|
||||
@@ -23,12 +24,14 @@ BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Type ty
|
||||
set_input_port_descriptor({0, offset_a}, 0);
|
||||
set_input_port_descriptor({0, offset_b}, 1);
|
||||
set_output_port_descriptor({0, offset_c}, 0);
|
||||
compute_block_size_values(blk_size_m, blk_size_k, blk_size_n);
|
||||
custom_constructor_validate_and_infer_types(std::move(layout_a), std::move(layout_b), std::move(layout_c));
|
||||
}
|
||||
|
||||
BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Output<Node>& scratch, const Type type,
|
||||
const size_t offset_a, const size_t offset_b, const size_t offset_scratch, const size_t offset_c,
|
||||
std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c)
|
||||
std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c,
|
||||
const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n)
|
||||
: Brgemm(), m_type(type) {
|
||||
set_arguments({A, B, scratch});
|
||||
set_output_size(1);
|
||||
@@ -37,28 +40,33 @@ BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Output<
|
||||
set_input_port_descriptor({0, offset_b}, 1);
|
||||
set_output_port_descriptor({0, offset_c}, 0);
|
||||
set_input_port_descriptor({0, offset_scratch}, 2);
|
||||
compute_block_size_values(blk_size_m, blk_size_k, blk_size_n);
|
||||
custom_constructor_validate_and_infer_types(std::move(layout_a), std::move(layout_b), std::move(layout_c));
|
||||
}
|
||||
|
||||
BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Type type,
|
||||
const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_c,
|
||||
std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c)
|
||||
std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c,
|
||||
const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n)
|
||||
: Brgemm(), m_type(type) {
|
||||
set_arguments({A, B});
|
||||
set_output_size(1);
|
||||
m_input_ports = {{0, desc_a}, {1, desc_b}};
|
||||
m_output_ports = {{0, desc_c}};
|
||||
compute_block_size_values(blk_size_m, blk_size_k, blk_size_n);
|
||||
custom_constructor_validate_and_infer_types(std::move(layout_a), std::move(layout_b), std::move(layout_c));
|
||||
}
|
||||
|
||||
BrgemmCPU::BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Output<Node>& scratch, const Type type,
|
||||
const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_scratch, const PortDescriptor& desc_c,
|
||||
std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c)
|
||||
std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c,
|
||||
const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n)
|
||||
: Brgemm(), m_type(type) {
|
||||
set_arguments({A, B, scratch});
|
||||
set_output_size(1);
|
||||
m_input_ports = {{0, desc_a}, {1, desc_b}, {2, desc_scratch}};
|
||||
m_output_ports = {{0, desc_c}};
|
||||
compute_block_size_values(blk_size_m, blk_size_k, blk_size_n);
|
||||
custom_constructor_validate_and_infer_types(std::move(layout_a), std::move(layout_b), std::move(layout_c));
|
||||
}
|
||||
|
||||
@@ -76,10 +84,18 @@ void BrgemmCPU::custom_constructor_validate_and_infer_types(std::vector<size_t>
|
||||
auto output_shape = get_output_partial_shape(planar_input_shapes);
|
||||
set_output_type(0, get_output_type(), snippets::utils::get_reordered_planar_shape(output_shape, layout_c));
|
||||
|
||||
//Additional check for 3rd input
|
||||
// Additional check for 3rd input
|
||||
validate_with_scratchpad(planar_input_shapes[1].get_shape());
|
||||
}
|
||||
|
||||
void BrgemmCPU::compute_block_size_values(const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n) {
|
||||
const auto input_shape_0 = snippets::utils::get_port_planar_shape(input(0)).get_shape();
|
||||
const auto input_shape_1 = snippets::utils::get_port_planar_shape(input(1)).get_shape();
|
||||
m_M_blk = blk_size_m != 0 ? blk_size_m : *(input_shape_0.rbegin() + 1);
|
||||
m_K_blk = blk_size_k != 0 ? blk_size_k : *input_shape_0.rbegin();
|
||||
m_N_blk = blk_size_n != 0 ? blk_size_n : *input_shape_1.rbegin();
|
||||
}
|
||||
|
||||
void BrgemmCPU::validate_and_infer_types() {
|
||||
INTERNAL_OP_SCOPE(BrgemmCPU_validate_and_infer_types);
|
||||
validate_inputs();
|
||||
@@ -89,28 +105,29 @@ void BrgemmCPU::validate_and_infer_types() {
|
||||
auto output_shape = get_output_partial_shape(planar_input_shapes);
|
||||
set_output_type(0, get_output_type(), get_planar_output_shape(output_shape));
|
||||
|
||||
//Additional check for 3rd input
|
||||
// Additional check for 3rd input
|
||||
validate_with_scratchpad(planar_input_shapes[1].get_shape());
|
||||
}
|
||||
|
||||
void BrgemmCPU::validate_with_scratchpad(const ov::Shape& shape_b) const {
|
||||
//Additional check for 3rd input
|
||||
// Additional check for 3rd input
|
||||
if (one_of(m_type, Type::WithCompensations, Type::AMX)) {
|
||||
const auto shape = get_input_partial_shape(2);
|
||||
NGRAPH_CHECK(shape.is_static(), "BRGEMM Scratch must have static shape");
|
||||
const auto& pshape = get_input_partial_shape(2);
|
||||
NGRAPH_CHECK(pshape.is_static(), "BRGEMM Scratch must have static shape");
|
||||
const auto shape = pshape.to_shape();
|
||||
const auto type = get_input_element_type(2);
|
||||
if (is_with_compensations()) {
|
||||
const auto element_type_b = get_input_element_type(0);
|
||||
const auto N = *shape_b.rbegin();
|
||||
const auto N_blk = element_type_b == element::f32 ? N :
|
||||
element_type_b == element::bf16 ? 32 : 64;
|
||||
const auto expected_shape = ov::Shape{rnd_up(N, N_blk)};
|
||||
const auto expected_type = ov::element::f32;
|
||||
NGRAPH_CHECK(expected_shape == shape.get_shape() && expected_type == type,
|
||||
"BRGEMM Scratch with compensations must have shape {rnd_up(N, N_blk)} and FP32 element type");
|
||||
NGRAPH_CHECK(expected_type == type, "BRGEMM Scratch with compensations must have FP32 element type");
|
||||
const auto N = *shape_b.rbegin();
|
||||
// If N block size is not set, there is no meaning in validating the scratchpad shape
|
||||
if (m_N_blk != N) {
|
||||
const auto expected_shape = ov::Shape{rnd_up(N, m_N_blk)};
|
||||
NGRAPH_CHECK(expected_shape == shape, "BRGEMM Scratch with compensations must have shape {rnd_up(N, m_N_blk)}");
|
||||
}
|
||||
} else {
|
||||
NGRAPH_CHECK(ngraph::shape_size(shape.get_shape()) == SCRATCH_BYTE_SIZE && type == ov::element::u8,
|
||||
"BRGEMM Scratch for space workplace must be static, have U8 element type and size is equal to " + std::to_string(SCRATCH_BYTE_SIZE));
|
||||
NGRAPH_CHECK(ov::shape_size(shape) == SCRATCH_BYTE_SIZE && type == ov::element::u8,
|
||||
"BRGEMM Scratch for space workplace must be static, have U8 element type and size equal to " + std::to_string(SCRATCH_BYTE_SIZE));
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -133,19 +150,27 @@ std::shared_ptr<Node> BrgemmCPU::clone_with_new_inputs(const OutputVector& new_a
|
||||
get_input_port_descriptor(0), get_input_port_descriptor(1), get_output_port_descriptor(0),
|
||||
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout(),
|
||||
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(1))->get_layout(),
|
||||
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout());
|
||||
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout(),
|
||||
m_M_blk, m_K_blk, m_N_blk);
|
||||
}
|
||||
return std::make_shared<BrgemmCPU>(new_args.at(0), new_args.at(1), new_args.at(2), m_type,
|
||||
get_input_port_descriptor(0), get_input_port_descriptor(1), get_input_port_descriptor(2), get_output_port_descriptor(0),
|
||||
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0))->get_layout(),
|
||||
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(1))->get_layout(),
|
||||
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout());
|
||||
snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(output(0))->get_layout(),
|
||||
m_M_blk, m_K_blk, m_N_blk);
|
||||
}
|
||||
|
||||
std::shared_ptr<BrgemmCopyB> BrgemmCPU::get_brgemm_copy() const {
|
||||
OPENVINO_ASSERT(one_of(m_type, Type::WithDataRepacking, Type::WithCompensations, Type::AMX), "Brgemm doesn't need BrgemmCopyB");
|
||||
if (const auto buffer = ov::as_type_ptr<snippets::op::Buffer>(get_input_node_shared_ptr(1))) {
|
||||
return ov::as_type_ptr<BrgemmCopyB>(buffer->get_input_node_shared_ptr(0));
|
||||
auto b_input_node = get_input_node_shared_ptr(1);
|
||||
if (const auto brgemm_copy_b = ov::as_type_ptr<BrgemmCopyB>(b_input_node)) {
|
||||
return brgemm_copy_b;
|
||||
}
|
||||
if (ov::is_type<snippets::op::Buffer>(b_input_node)) {
|
||||
if (const auto brgemm_copy_b = ov::as_type_ptr<BrgemmCopyB>(b_input_node->get_input_node_shared_ptr(0))) {
|
||||
return brgemm_copy_b;
|
||||
}
|
||||
}
|
||||
OPENVINO_THROW("BrgemmCopyB hasn't been found!");
|
||||
}
|
||||
|
||||
@@ -31,22 +31,34 @@ public:
|
||||
|
||||
BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Type type,
|
||||
const size_t offset_a = 0, const size_t offset_b = 0, const size_t offset_c = 0,
|
||||
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {});
|
||||
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
|
||||
const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0);
|
||||
BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Output<Node>& scratch, const Type type,
|
||||
const size_t offset_a = 0, const size_t offset_b = 0, const size_t offset_scratch = 0, const size_t offset_c = 0,
|
||||
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {});
|
||||
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
|
||||
const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0);
|
||||
BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Type type,
|
||||
const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_c,
|
||||
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {});
|
||||
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
|
||||
const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0);
|
||||
BrgemmCPU(const Output<Node>& A, const Output<Node>& B, const Output<Node>& scratch, const Type type,
|
||||
const PortDescriptor& desc_a, const PortDescriptor& desc_b, const PortDescriptor& desc_scratch, const PortDescriptor& desc_c,
|
||||
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {});
|
||||
std::vector<size_t> layout_a = {}, std::vector<size_t> layout_b = {}, std::vector<size_t> layout_c = {},
|
||||
const size_t blk_size_m = 0, const size_t blk_size_k = 0, const size_t blk_size_n = 0);
|
||||
BrgemmCPU() = default;
|
||||
|
||||
void validate_and_infer_types() override;
|
||||
std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
|
||||
|
||||
Type get_type() const { return m_type; }
|
||||
size_t get_m_block_size() const { return m_M_blk; }
|
||||
size_t get_k_block_size() const { return m_K_blk; }
|
||||
size_t get_n_block_size() const { return m_N_blk; }
|
||||
|
||||
void set_m_block_size(size_t block_size) { m_M_blk = block_size; }
|
||||
void set_k_block_size(size_t block_size) { m_K_blk = block_size; }
|
||||
void set_n_block_size(size_t block_size) { m_N_blk = block_size; }
|
||||
|
||||
bool is_with_compensations() const { return m_type == Type::WithCompensations; }
|
||||
bool is_with_data_repacking() const { return m_type != Type::Floating; }
|
||||
bool is_amx() const { return m_type == Type::AMX; }
|
||||
@@ -59,10 +71,14 @@ public:
|
||||
|
||||
private:
|
||||
void custom_constructor_validate_and_infer_types(std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c);
|
||||
void compute_block_size_values(const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n);
|
||||
void validate_with_scratchpad(const ov::Shape& shape_b) const;
|
||||
void validate_inputs() const;
|
||||
|
||||
Type m_type = Type::Floating;
|
||||
size_t m_M_blk = 0;
|
||||
size_t m_K_blk = 0;
|
||||
size_t m_N_blk = 0;
|
||||
};
|
||||
|
||||
} // namespace intel_cpu
|
||||
|
||||
@@ -78,7 +78,7 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() {
|
||||
const auto offset_b = brgemm->get_offset_b();
|
||||
const auto offset_c = brgemm->get_offset_c();
|
||||
|
||||
std::shared_ptr<ov::Node> brgemm_cpu = nullptr;
|
||||
std::shared_ptr<BrgemmCPU> brgemm_cpu = nullptr;
|
||||
std::shared_ptr<BrgemmCopyB> brgemm_repacking = nullptr;
|
||||
if (element_type_a == ov::element::f32) {
|
||||
brgemm_cpu = std::make_shared<BrgemmCPU>(brgemm->input_value(0), brgemm->input_value(1), BrgemmCPU::Type::Floating,
|
||||
@@ -88,30 +88,24 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() {
|
||||
const auto copy_b_type = with_comp ? BrgemmCopyB::WithCompensations : BrgemmCopyB::OnlyRepacking;
|
||||
brgemm_repacking = std::make_shared<BrgemmCopyB>(brgemm->input_value(1), element_type_a, copy_b_type, offset_b, 0, 0,
|
||||
brgemm_in1_desc->get_layout());
|
||||
const auto buffer = std::make_shared<snippets::op::Buffer>(brgemm_repacking->output(0));
|
||||
set_port_desc(brgemm_repacking->input(0), brgemm_in1_desc->get_shape(), brgemm_in1_desc->get_subtensor(), brgemm_in1_desc->get_layout());
|
||||
set_full_port_desc(brgemm_repacking->output(0));
|
||||
set_full_port_desc(buffer->input(0));
|
||||
set_full_port_desc(buffer->output(0));
|
||||
|
||||
if (with_amx) {
|
||||
const auto scratch = std::make_shared<snippets::op::Buffer>(ov::Shape{BrgemmCPU::SCRATCH_BYTE_SIZE});
|
||||
brgemm_cpu = std::make_shared<BrgemmCPU>(brgemm->input_value(0), buffer, scratch, BrgemmCPU::Type::AMX,
|
||||
brgemm_cpu = std::make_shared<BrgemmCPU>(brgemm->input_value(0), brgemm_repacking->output(0), scratch, BrgemmCPU::Type::AMX,
|
||||
offset_a, offset_b, 0, offset_c,
|
||||
brgemm_in0_desc->get_layout(), std::vector<size_t>{}, brgemm_out_desc->get_layout());
|
||||
set_full_port_desc(scratch->output(0));
|
||||
set_full_port_desc(brgemm_cpu->input(2));
|
||||
} else if (with_comp) {
|
||||
const auto scratch = std::make_shared<snippets::op::Buffer>(brgemm_repacking->output(1));
|
||||
brgemm_cpu = std::make_shared<BrgemmCPU>(brgemm->input_value(0), buffer, scratch, BrgemmCPU::Type::WithCompensations,
|
||||
offset_a, offset_b, 0, offset_c,
|
||||
brgemm_cpu = std::make_shared<BrgemmCPU>(brgemm->input_value(0), brgemm_repacking->output(0), brgemm_repacking->output(1),
|
||||
BrgemmCPU::Type::WithCompensations, offset_a, offset_b, 0, offset_c,
|
||||
brgemm_in0_desc->get_layout(), std::vector<size_t>{}, brgemm_out_desc->get_layout());
|
||||
set_full_port_desc(brgemm_repacking->output(1));
|
||||
set_full_port_desc(scratch->input(0));
|
||||
set_full_port_desc(scratch->output(0));
|
||||
set_full_port_desc(brgemm_cpu->input(2));
|
||||
} else if (one_of(element_type_a, ov::element::u8, ov::element::bf16)) {
|
||||
brgemm_cpu = std::make_shared<BrgemmCPU>(brgemm->input_value(0), buffer, BrgemmCPU::Type::WithDataRepacking,
|
||||
brgemm_cpu = std::make_shared<BrgemmCPU>(brgemm->input_value(0), brgemm_repacking->output(0), BrgemmCPU::Type::WithDataRepacking,
|
||||
offset_a, offset_b, offset_c,
|
||||
brgemm_in0_desc->get_layout(), std::vector<size_t>{}, brgemm_out_desc->get_layout());
|
||||
} else {
|
||||
|
||||
@@ -27,12 +27,9 @@ bool BrgemmBlocking::run(snippets::lowered::LinearIR& linear_ir) {
|
||||
if (linear_ir.empty())
|
||||
return false;
|
||||
|
||||
// Ticket: 113745
|
||||
// TODO: make the block size configurable
|
||||
const auto block_size = 32;
|
||||
const auto dim_idx = 1;
|
||||
|
||||
const auto& loop_manager = linear_ir.get_loop_manager();
|
||||
const auto dim_idx = 1;
|
||||
|
||||
auto blocking_loop_exists = [&](const ov::snippets::lowered::ExpressionPtr& expr,
|
||||
const std::shared_ptr<ov::intel_cpu::BrgemmCPU>& brgemm) {
|
||||
@@ -61,6 +58,7 @@ bool BrgemmBlocking::run(snippets::lowered::LinearIR& linear_ir) {
|
||||
const auto& dim = *(input_layout_0.rbegin() + dim_idx);
|
||||
const auto& m = input_shape_0[dim];
|
||||
|
||||
const auto block_size = brgemm->get_m_block_size();
|
||||
brgemm->set_input_count(block_size);
|
||||
|
||||
const auto work_amount = m;
|
||||
|
||||
@@ -0,0 +1,103 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "snippets/itt.hpp"
|
||||
|
||||
#include "set_brgemm_cpu_blocking_params.hpp"
|
||||
|
||||
#include "snippets/utils.hpp"
|
||||
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
|
||||
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
|
||||
|
||||
#include "openvino/core/rt_info.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "openvino/pass/pattern/matcher.hpp"
|
||||
|
||||
#include <cpu/x64/cpu_isa_traits.hpp>
|
||||
|
||||
#include "cpu_shape.h"
|
||||
#include "utils/general_utils.h"
|
||||
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
using namespace snippets::lowered;
|
||||
namespace {
|
||||
template <typename T>
|
||||
void change_desc_shape(const T& port) {
|
||||
const auto desc = PortDescriptorUtils::get_port_descriptor_ptr(port);
|
||||
const auto& shape = port.get_shape();
|
||||
if (desc->get_shape() != shape) {
|
||||
desc->set_shape(shape);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// Registers a matcher on every BrgemmCPU node. For static nodes the callback
// chooses M/K/N block sizes, stores them on the node (and on the related
// BrgemmCopyB when data repacking is used), re-runs shape inference and refreshes
// the affected port descriptors. The callback always returns false.
pass::SetBrgemmCPUBlockingParams::SetBrgemmCPUBlockingParams() {
    MATCHER_SCOPE(SetBrgemmCPUBlockingParams);

    auto brgemm_pattern = ov::pass::pattern::wrap_type<BrgemmCPU>();

    auto callback = [=](ov::pass::pattern::Matcher& matcher) {
        OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::SetBrgemmCPUBlockingParams")
        auto brgemm = ov::as_type_ptr<BrgemmCPU>(matcher.get_match_root());
        // Block sizes are derived from concrete dimensions, so dynamic nodes are left as is
        if (brgemm->is_dynamic())
            return false;

        const auto planar_shape_a = snippets::utils::get_port_planar_shape(brgemm->input_value(0)).get_shape();
        const auto planar_shape_b = snippets::utils::get_port_planar_shape(brgemm->input_value(1)).get_shape();
        const auto K = planar_shape_a.back();
        const auto N = planar_shape_b.back();

        const bool is_input_1_f32 = brgemm->get_input_element_type(1) == ov::element::f32;

        // Ticket: 113745
        // TODO: extend block size selection heuristics
        const size_t m_blk = 32;
        // K and N are blocked only for f32 second input; otherwise the whole dimension is one block
        size_t k_blk = K;
        if (is_input_1_f32) {
            if (K > 1024) {
                k_blk = 1024;
            } else if (K > 512) {
                k_blk = 512;
            }
        }
        const size_t n_blk = is_input_1_f32 ? 64 : N;

        brgemm->set_m_block_size(m_blk);
        brgemm->set_k_block_size(k_blk);
        brgemm->set_n_block_size(n_blk);

        if (brgemm->is_with_data_repacking()) {
            const auto copy_b = brgemm->get_brgemm_copy();

            const bool amx_available = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx);
            const auto src_precision = copy_b->get_src_element_type();
            const auto vnni_factor = 4 / src_precision.size();

            // The repacking follows the brgemm K block only in the AMX case; otherwise it covers the whole K
            size_t copy_b_k_blk = K;
            if (amx_available && src_precision != ov::element::f32 && (K % vnni_factor == 0) && (N % vnni_factor == 0)) {
                copy_b_k_blk = k_blk;
            }
            const size_t copy_b_n_blk = 64;

            copy_b->set_k_block_size(copy_b_k_blk);
            copy_b->set_n_block_size(copy_b_n_blk);
            // The N block size affects the output shapes, so shape inference has to be
            // re-run explicitly right after the block sizes are changed
            copy_b->validate_and_infer_types();
            change_desc_shape(copy_b->output(0));
            if (copy_b->is_with_compensations()) {
                change_desc_shape(copy_b->output(1));
            }
        }

        brgemm->validate_and_infer_types();
        change_desc_shape(brgemm->input(1));
        if (brgemm->is_with_scratchpad()) {
            change_desc_shape(brgemm->input(2));
        }

        return false;
    };

    auto brgemm_matcher = std::make_shared<ov::pass::pattern::Matcher>(brgemm_pattern, matcher_name);
    register_matcher(brgemm_matcher, callback);
}
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
||||
@@ -0,0 +1,27 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "openvino/pass/graph_rewrite.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
namespace pass {
|
||||
|
||||
/**
|
||||
* @interface SetBrgemmCPUBlockingParams
|
||||
* @brief The pass selects optimal M, K and N blocking parameters for BrgemmCPU and sets them to the node.
|
||||
* @ingroup snippets
|
||||
*/
|
||||
// Matcher pass over BrgemmCPU nodes; the matching logic is set up in the constructor.
class SetBrgemmCPUBlockingParams: public ov::pass::MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("SetBrgemmCPUBlockingParams", "0");
|
||||
// Builds the BrgemmCPU pattern and registers the callback that assigns the block sizes.
SetBrgemmCPUBlockingParams();
|
||||
};
|
||||
|
||||
|
||||
} // namespace pass
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
||||
@@ -680,22 +680,14 @@ void Transformations::MainSnippets(void) {
|
||||
const auto shape = pshape.get_shape();
|
||||
const auto parallel_work_amount =
|
||||
std::accumulate(shape.rbegin() + 2, shape.rend(), 1, std::multiplies<size_t>());
|
||||
const auto kernel_buffer_size =
|
||||
std::accumulate(shape.rbegin(), shape.rbegin() + 2, 1, std::multiplies<size_t>()) *
|
||||
n->get_output_element_type(0).size();
|
||||
// Heuristic values:
|
||||
// parallelism work amount - not enough work amount for parallelism
|
||||
// kernel work amount - large shape for kernel execution, not cache-local
|
||||
// TODO: The heuristics will be removed after
|
||||
// - loop blocking support on code generation level
|
||||
// - parallelism support on JIT level
|
||||
// TODO: The heuristic will be removed after parallelism support on JIT level
|
||||
const auto needed_num_of_threads = 12lu;
|
||||
const auto l2_cache_size = dnnl::utils::get_cache_size(2, true);
|
||||
const auto is_unsupported_parallel_work_amount =
|
||||
parallel_get_num_threads() / 2 > parallel_work_amount &&
|
||||
static_cast<size_t>(parallel_work_amount) < needed_num_of_threads;
|
||||
const auto is_unsupported_kernel_work_amount = kernel_buffer_size > l2_cache_size;
|
||||
return is_unsupported_parallel_work_amount || is_unsupported_kernel_work_amount;
|
||||
return is_unsupported_parallel_work_amount;
|
||||
},
|
||||
snippets::pass::TokenizeMHASnippets);
|
||||
CPU_SET_CALLBACK_X64(snippetsManager,
|
||||
|
||||
@@ -17,7 +17,9 @@ std::vector<std::vector<ov::PartialShape>> input_shapes{
|
||||
{{3, 1, 32, 14}, {1, 2, 14, 32}},
|
||||
{{1, 2, 37, 23}, {2, 1, 23, 37}},
|
||||
{{1, 1, 37, 23}, {1, 2, 23, 33}},
|
||||
{{1, 16, 384, 64}, {1, 16, 64, 384}}
|
||||
{{1, 1, 32, 23}, {1, 1, 23, 68}},
|
||||
{{1, 16, 384, 64}, {1, 16, 64, 384}},
|
||||
{{1, 1, 100, 700}, {1, 1, 700, 100}},
|
||||
};
|
||||
|
||||
static inline std::vector<std::vector<element::Type>> quantized_precisions() {
|
||||
|
||||
Reference in New Issue
Block a user