[Snippets] MatMul: blocking by M dimension at the LIR stage (#18169)

* [Snippets] MatMul: blocking by M dim at LIR level

* Alexandra's comments applied

* Ivan's comments applied

* Fix warning
This commit is contained in:
Vladislav Golubev 2023-06-28 19:31:44 +02:00 committed by GitHub
parent 3bc8065ca3
commit 18e737493c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
17 changed files with 434 additions and 87 deletions

View File

@ -76,11 +76,19 @@ public:
LinearIR::constExprIt loop_end_pos,
size_t loop_depth, size_t vector_size);
// Return Loop ID
template <typename T>
size_t mark_loop(LinearIR::constExprIt loop_begin_pos,
LinearIR::constExprIt loop_end_pos,
size_t work_amount, size_t work_amount_increment, size_t dim_idx,
const std::vector<ExpressionPort>& entries,
const std::vector<ExpressionPort>& exits);
const std::vector<T>& entries,
const std::vector<T>& exits) {
const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, work_amount_increment, dim_idx, entries, exits);
const auto loop_id = this->add_loop_info(loop_info);
for (auto expr_it = loop_begin_pos; expr_it != loop_end_pos; ++expr_it) {
insert_loop_id(*expr_it, loop_id);
}
return loop_id;
}
void fuse_loops(const LinearIR& linear_ir, size_t loop_id_upper, size_t loop_id_lower, bool fuse_into_upper = true);
void fuse_loops(LinearIR::constExprIt loop_begin_target, LinearIR::constExprIt loop_end_target,
@ -123,6 +131,8 @@ public:
LinearIR::constExprIt& loop_end_pos,
size_t loop_id, bool loop_ops_inserted = false);
LoopPort get_loop_port_by_expr_port(const ExpressionPort& expr_port, const size_t loop_id);
private:
static void get_io_loop_ports(LinearIR::constExprIt loop_begin_pos,
LinearIR::constExprIt loop_end_pos,

View File

@ -42,6 +42,12 @@ public:
FuseLoops();
bool run(LinearIR& linear_ir) override;
// This method checks that all ports which connect lower and upper loops are incremented.
// This helps to avoid fusing for the ports with incomplete data
static bool loop_ports_are_compatible(const LinearIR::LoopManagerPtr& loop_manager,
const size_t loop_lower_id,
const size_t loop_upper_id);
private:
static bool can_be_fused(const LinearIR::LoopManager::LoopInfoPtr& loop_current,
const LinearIR::LoopManager::LoopInfoPtr& loop_target);

View File

@ -0,0 +1,46 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "pass.hpp"
#include "snippets/lowered/loop_manager.hpp"
namespace ov {
namespace snippets {
namespace lowered {
namespace pass {
/**
 * @interface SplitLoops
 * @brief If loop_1 has a larger increment but the same work amount as loop_2, which follows loop_1, then split loop_2
 * into two loops so the outermost of the split loops can be fused with loop_1 by the `FuseLoops` pass.
 * Example:
 *      Loop_1_begin                                 Loop_1_begin                              Loop_1_begin
 *          ...                                          ...                                       ...
 *      Loop_1_end (wa = 128, inc = 32)              Loop_1_end (wa = 128, inc = 32)           Split_loop_2_begin
 *          ...                         Splitting        ...                       Fusing          ...
 *      Loop_2_begin                   =>            Split_loop_1_begin           =>           Split_loop_2_end (wa = 32, inc = 1)
 *          ...                                      Split_loop_2_begin                            ...
 *      Loop_2_end (wa = 128, inc = 1)                   ...                                   Loop_1_end (wa = 128, inc = 32)
 *                                                   Split_loop_2_end (wa = 32, inc = 1)
 *                                                   Split_loop_1_end (wa = 128, inc = 32)
 * @ingroup snippets
 */
class SplitLoops : public Pass {
public:
    OPENVINO_RTTI("SplitLoops", "Pass")
    SplitLoops();
    bool run(LinearIR& linear_ir) override;

private:
    // True when `current` and `target` iterate over the same dimension with equal work amounts
    // but different increments, i.e. splitting one of them could make them fusable.
    static bool can_be_split(const LinearIR::LoopManager::LoopInfoPtr& current,
                             const LinearIR::LoopManager::LoopInfoPtr& target);
};
} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov

View File

@ -104,12 +104,14 @@ public:
ov::pass::Manager& pre_common,
ov::pass::Manager& post_common,
ov::pass::Manager& post_precision,
lowered::pass::PassPipeline& target_lowered_markup_pipeline,
lowered::pass::PassPipeline& target_lowered_pipeline,
const void* compile_params = nullptr);
snippets::Schedule generate(const BlockedShapeVector& output_shapes, const BlockedShapeVector& input_shapes, const void* compile_params = nullptr);
snippets::Schedule generate(ov::pass::Manager& pre_common,
ov::pass::Manager& post_common,
ov::pass::Manager& post_precision,
lowered::pass::PassPipeline& target_lowered_markup_pipeline,
lowered::pass::PassPipeline& target_lowered_pipeline,
const void* compile_params = nullptr);
snippets::Schedule generate(const void* compile_params = nullptr);
@ -144,7 +146,9 @@ public:
private:
void align_element_types(const BlockedShapeVector& outputShapes, const BlockedShapeVector& inputShapes);
void data_flow_transformations(ov::pass::Manager& pre_common, ov::pass::Manager& post_common, ov::pass::Manager& post_precision);
void control_flow_transformations(lowered::LinearIR& linear_ir, lowered::pass::PassPipeline& target_pipeline);
void control_flow_transformations(lowered::LinearIR& linear_ir,
lowered::pass::PassPipeline& target_markup_pipeline,
lowered::pass::PassPipeline& target_pipeline);
void init_config();
// Count of Subgraph virtual ports:
// - Potential non-scalar Constants that will be created after some transformations (At the moment it's relevant only for FakeQuantize decomposition)

View File

@ -113,6 +113,18 @@ void LinearIR::LoopManager::get_loop_bounds(const LinearIR &linear_ir,
}
}
LinearIR::LoopManager::LoopPort LinearIR::LoopManager::get_loop_port_by_expr_port(const ExpressionPort& expr_port, const size_t loop_id) {
    // Look up the loop port that wraps `expr_port` inside loop `loop_id`.
    const auto& loop_info = get_loop_info(loop_id);
    // Entry points hold input ports, exit points hold output ports.
    const auto& ports = expr_port.get_type() == ExpressionPort::Input ? loop_info->entry_points
                                                                      : loop_info->exit_points;
    const auto found = std::find_if(ports.cbegin(), ports.cend(),
                                    [&expr_port](const LinearIR::LoopManager::LoopPort& candidate) {
                                        return *candidate.expr_port == expr_port;
                                    });
    if (found == ports.cend())
        OPENVINO_THROW("Expression has not been found among loop ports. Loop id: " + std::to_string(loop_id));
    return *found;
}
void LinearIR::LoopManager::get_io_loop_ports(LinearIR::constExprIt loop_begin_pos,
LinearIR::constExprIt loop_end_pos,
std::vector<ExpressionPort> &entries,
@ -211,18 +223,6 @@ void LinearIR::LoopManager::mark_loop(LinearIR::constExprIt loop_begin_pos,
}
}
size_t LinearIR::LoopManager::mark_loop(LinearIR::constExprIt loop_begin_pos,
                                        LinearIR::constExprIt loop_end_pos,
                                        size_t work_amount, size_t work_amount_increment, size_t dim_idx,
                                        const std::vector<ExpressionPort>& entries,
                                        const std::vector<ExpressionPort>& exits) {
    // Register a new loop descriptor and tag every expression in [begin, end) with its id.
    const auto new_loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, work_amount_increment, dim_idx, entries, exits);
    const auto new_loop_id = this->add_loop_info(new_loop_info);
    for (auto it = loop_begin_pos; it != loop_end_pos; ++it)
        insert_loop_id(*it, new_loop_id);
    return new_loop_id;
}
void LinearIR::LoopManager::fuse_loops(const LinearIR& linear_ir, size_t loop_id_upper, size_t loop_id_lower, bool fuse_into_upper) {
LinearIR::constExprIt loop_begin_target, loop_end_target;
get_loop_bounds(linear_ir, fuse_into_upper ? loop_id_lower : loop_id_upper, loop_begin_target, loop_end_target);

View File

@ -24,6 +24,23 @@ using LoopInfoPtr = LoopManager::LoopInfoPtr;
FuseLoops::FuseLoops() : Pass() {}
bool FuseLoops::loop_ports_are_compatible(const LinearIR::LoopManagerPtr& loop_manager,
                                          const size_t loop_lower_id,
                                          const size_t loop_upper_id) {
    // Every entry of the lower loop that is fed from inside the upper loop must be incremented
    // on both ends; otherwise fusion would read partially-updated data.
    const auto loop_lower = loop_manager->get_loop_info(loop_lower_id);
    for (const auto& entry : loop_lower->entry_points) {
        const auto& src_port = entry.expr_port->get_port_connector_ptr()->get_source();
        if (!is_loop_id_found(src_port.get_expr()->get_loop_ids(), loop_upper_id))
            continue;  // the source does not belong to the upper loop — no constraint
        if (!entry.is_incremented)
            return false;
        const auto src_loop_port = loop_manager->get_loop_port_by_expr_port(src_port, loop_upper_id);
        if (!src_loop_port.is_incremented)
            return false;
    }
    return true;
}
bool FuseLoops::can_be_fused(const LoopInfoPtr& loop_current, const LoopInfoPtr& loop_target) {
auto current_work_amount = loop_current->work_amount;
auto target_work_amount = loop_target->work_amount;
@ -79,7 +96,7 @@ bool FuseLoops::fuse_upper_into_current(LinearIR& linear_ir, const LinearIR::Loo
LinearIR::constExprIt& current_loop_begin_pos, LinearIR::constExprIt& current_loop_end_pos) {
const auto& loop_current = loop_manager->get_loop_info(current_loop_id);
const auto& loop_target = loop_manager->get_loop_info(target_loop_id);
if (!can_be_fused(loop_current, loop_target))
if (!can_be_fused(loop_current, loop_target) || !loop_ports_are_compatible(loop_manager, current_loop_id, target_loop_id))
return false;
// We can fuse Loop_up to Loop_down only in cases when other consumers of Loop_up are after Loop_down
@ -129,7 +146,7 @@ bool FuseLoops::fuse_lower_into_current(LinearIR& linear_ir, const LinearIR::Loo
LinearIR::constExprIt& current_loop_begin_pos, LinearIR::constExprIt& current_loop_end_pos) {
const auto& loop_current = loop_manager->get_loop_info(current_loop_id);
const auto& loop_target = loop_manager->get_loop_info(target_loop_id);
if (!can_be_fused(loop_current, loop_target))
if (!can_be_fused(loop_current, loop_target) || !loop_ports_are_compatible(loop_manager, target_loop_id, current_loop_id))
return false;
// We can fuse Loop_down to Loop_up only in cases when other parents of Loop_down are before Loop_up

View File

@ -51,10 +51,15 @@ void InitLoops::init_ptr_increments(std::vector<LoopPort>& loop_inputs, std::vec
const auto& layout = port->get_descriptor_ptr()->get_layout();
const auto& shape = port->get_descriptor_ptr()->get_shape();
const auto& dim = *(layout.rbegin() + dim_idx);
// Ticket: 113106
// WA: the current logic doesn't support the case with transposed output shape for brgemm layer
// but for all existing cases planar layout can be used
std::vector<size_t> planar(layout.size());
std::iota(planar.begin(), planar.end(), 0);
loop_output.ptr_increment = 0;
// If relevant dim is not broadcasted, then ptr_increment is the dim stride in the new layout
if (loop_output.is_incremented && !(shape[dim] == 1 && work_amount != 1)) {
loop_output.ptr_increment = get_dim_stride(dim, layout, shape);
loop_output.ptr_increment = get_dim_stride(dim, planar, shape);
}
}
}

View File

@ -4,9 +4,10 @@
#include "snippets/lowered/pass/insert_buffers.hpp"
#include "snippets/itt.hpp"
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/snippets_isa.hpp"
#include "snippets/itt.hpp"
#include "snippets/utils.hpp"
namespace ov {
@ -28,6 +29,49 @@ std::vector<size_t> get_buffer_loop_ids(const std::vector<size_t>& lhs, const st
}
return buffer_loop_ids;
}
// Ticket: 113744
// TODO: This logic covers only several specific cases so it should be generalized.
// Computes the shape to allocate for a Buffer placed on `parent_output`, shrinking
// dimensions that are fully covered by surrounding loops.
ov::Shape compute_allocation_shape(const LinearIR::LoopManagerPtr& loop_manager,
                                   const std::vector<size_t>& buffer_loop_ids,
                                   const std::vector<size_t>& parent_loop_ids,
                                   const ov::Output<ov::Node>& parent_output,
                                   const int allocation_rank) {
    // Negative allocation_rank means "allocate the full rank of the parent output".
    const size_t rank = allocation_rank >= 0 ? allocation_rank : parent_output.get_shape().size();
    ov::Shape allocation_shape(rank);
    // Start from the planar (layout-reordered) parent shape, keeping the `rank` innermost dims.
    const auto port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(parent_output);
    const auto planar_shape = utils::get_reordered_planar_shape(ov::Shape{port->get_shape()}, port->get_layout());
    for (size_t i = 0; i < rank; ++i) {
        *(allocation_shape.rbegin() + i) = (planar_shape.rbegin() + i)->get_length();
    }
    // Without loop context there is nothing to shrink by — allocate the full planar shape.
    if (buffer_loop_ids.empty() || parent_loop_ids.empty()) {
        return allocation_shape;
    }
    // Collapses all outer dimensions that are not explicitly sized below to 1.
    auto set_rest_dims_to_ones = [&](const int filled_dims_count) {
        for (int i = 0; i < static_cast<int>(allocation_shape.size()) - filled_dims_count; ++i) {
            allocation_shape[i] = 1;
        }
    };
    // In some cases it's possible to allocate less shape
    // 1. Buffer and its parent are in the same loop: allocation size for the outer dimension can be extracted from loop increment
    // 2. Buffer is outside the parent's loops: allocation size can be extracted from the corresponding loop work amount
    // TODO: Use general logic with the help of memory counts for allocation shape computation
    if (buffer_loop_ids.back() == parent_loop_ids.back()) {
        const auto buffer_loop = loop_manager->get_loop_info(buffer_loop_ids.back());
        // NOTE(review): `rbegin() + 1` assumes rank >= 2 — TODO confirm callers guarantee this.
        *(allocation_shape.rbegin() + 1) = buffer_loop->increment;
        set_rest_dims_to_ones(2);
    } else {
        for (size_t i = 0; i < std::min(rank, parent_loop_ids.size()); ++i) {
            const auto loop = loop_manager->get_loop_info(*(parent_loop_ids.rbegin() + i));
            *(allocation_shape.rbegin() + i) = loop->work_amount;
        }
        set_rest_dims_to_ones(static_cast<int>(parent_loop_ids.size()));
    }
    return allocation_shape;
}
} // namespace
InsertBuffers::InsertBuffers(int32_t buffer_allocation_rank)
@ -110,7 +154,12 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPt
// Current expr Loop identifies: 3, 4, 6
// Need to insert between 2nd and 4th Loops - after 2nd Loop
const auto pos = insertion_position(linear_ir, loop_manager, parent_expr, expr);
const auto buffer = std::make_shared<op::Buffer>(parent->output(parent_port), m_buffer_allocation_rank);
const auto allocation_shape = compute_allocation_shape(loop_manager,
buffer_loop_ids,
parent_loops,
parent->output(parent_port),
m_buffer_allocation_rank);
const auto buffer = std::make_shared<op::Buffer>(parent->output(parent_port), allocation_shape);
PortDescriptorUtils::set_port_descriptor_ptr(buffer->output(0), parent_expr_output.get_descriptor_ptr()->clone());
// Output connector is automatically filled from PortDescriptor
const auto buffer_expr = linear_ir.create_expression(buffer, {input_connector});
@ -183,7 +232,12 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPt
// Note: All potential consumers must have the same count of first equal Loop identifies and the same count of different last identifies
const auto pos = insertion_position(linear_ir, loop_manager, expr, (*potential_consumers.begin()).get_expr());
auto buffer = std::make_shared<op::Buffer>(node->output(port), m_buffer_allocation_rank);
const auto allocation_shape = compute_allocation_shape(loop_manager,
buffer_loop_ids,
current_loops,
node->output(port),
m_buffer_allocation_rank);
auto buffer = std::make_shared<op::Buffer>(node->output(port), allocation_shape);
PortDescriptorUtils::set_port_descriptor_ptr(buffer->output(0), exit_port->get_descriptor_ptr()->clone());
// We cannot insert Node output connector on Buffer output because not all consumers of Node needs Buffer
// Example:

View File

@ -0,0 +1,96 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "snippets/lowered/pass/split_loops.hpp"
#include "snippets/lowered/pass/fuse_loops.hpp"
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/loop_manager.hpp"
#include "snippets/snippets_isa.hpp"
#include "snippets/itt.hpp"
namespace ov {
namespace snippets {
namespace lowered {
namespace pass {
using LoopManager = LinearIR::LoopManager;
using LoopInfoPtr = LoopManager::LoopInfoPtr;
SplitLoops::SplitLoops() : Pass() {}
bool SplitLoops::can_be_split(const LoopInfoPtr& current, const LoopInfoPtr& parent) {
    // Splitting only helps when both loops walk the same dimension over the same work amount
    // but advance with different increments.
    const bool same_work_amount = current->work_amount == parent->work_amount;
    const bool same_dimension = current->dim_idx == parent->dim_idx;
    const bool different_increment = current->increment != parent->increment;
    return same_work_amount && same_dimension && different_increment;
}
// Scans the IR for outermost loops whose producer loop has the same work amount but a larger
// increment, and splits the smaller-increment loop so that FuseLoops can merge the pair.
bool SplitLoops::run(LinearIR& linear_ir) {
    OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::SplitLoops")
    if (linear_ir.empty())
        return false;

    const auto& loop_manager = linear_ir.get_loop_manager();
    bool loop_was_split = false;
    for (const auto& expr : linear_ir) {
        const auto& loop_ids = expr->get_loop_ids();
        if (loop_ids.empty())
            continue;

        // Ticket: 113755
        // Note: we currently consider only the outermost loops for splitting
        // Splitting could also be done in a more general case, but the split loop and its parent must always
        // be in the same set of outer loops. Otherwise they won't be fused.
        const auto& loop_id = loop_ids.front();
        const auto loop = loop_manager->get_loop_info(loop_id);
        for (const auto& entry_point : loop->entry_points) {
            const auto& parent_port = entry_point.expr_port->get_port_connector_ptr()->get_source();
            const auto& parent_expr = parent_port.get_expr();
            const auto parent_loop_ids = parent_expr->get_loop_ids();
            if (parent_loop_ids.empty())
                continue;

            const auto& parent_loop_id = parent_loop_ids.front();
            // NOTE(review): the result is unused; presumably kept for the throwing side effect of
            // get_loop_port_by_expr_port when the port is absent from the parent loop — TODO confirm.
            const auto parent_loop_port = loop_manager->get_loop_port_by_expr_port(parent_port, parent_loop_id);
            // We don't split loops which are not compatible with the parent loop because such loops will not be fused
            if (!FuseLoops::loop_ports_are_compatible(loop_manager, loop_id, parent_loop_id))
                continue;

            const auto parent_loop = loop_manager->get_loop_info(parent_loop_id);
            if (can_be_split(loop, parent_loop)) {
                loop_was_split = true;
                // Split whichever loop has the smaller increment; the other is the fusion target.
                const bool split_parent = parent_loop->increment < loop->increment;
                const auto& loop_to_split = split_parent ? parent_loop : loop;
                const auto& loop_to_split_id = split_parent ? parent_loop_id : loop_id;
                const auto& loop_to_fuse = !split_parent ? parent_loop : loop;
                // The inner (split) loop now covers one block of the fused loop's increment.
                loop_to_split->work_amount = loop_to_fuse->increment;

                LinearIR::constExprIt loop_begin_pos, loop_end_pos;
                LoopManager::get_loop_bounds(linear_ir,
                                             loop_to_split->entry_points,
                                             loop_to_split->exit_points,
                                             loop_begin_pos,
                                             loop_end_pos,
                                             loop_to_split_id);
                // The outer split loop keeps the original work amount with the fused loop's increment,
                // making it fusable with the target loop.
                const auto split_loop_id = loop_manager->mark_loop(loop_begin_pos,
                                                                   loop_end_pos,
                                                                   loop_to_fuse->work_amount,
                                                                   loop_to_fuse->increment,
                                                                   loop_to_split->dim_idx,
                                                                   loop_to_split->entry_points,
                                                                   loop_to_split->exit_points);
                loop_manager->get_loop_info(split_loop_id)->outer_splited_loop = true;
                break;
            }
        }
    }
    // Ticket: 113666
    // FuseLoops pass is explicitly run here in order to avoid unnecessary computations
    // in case if loops are not split but FuseLoops is registered in pass manager after SplitLoops
    if (loop_was_split)
        FuseLoops().run(linear_ir);

    return loop_was_split;
}
} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov

View File

@ -24,6 +24,7 @@
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/pass/assign_registers.hpp"
#include "snippets/lowered/pass/mark_loops.hpp"
#include "snippets/lowered/pass/split_loops.hpp"
#include "snippets/lowered/pass/fuse_loops.hpp"
#include "snippets/lowered/pass/init_loops.hpp"
#include "snippets/lowered/pass/insert_buffers.hpp"
@ -507,6 +508,7 @@ void snippets::op::Subgraph::data_flow_transformations(ov::pass::Manager& pre_co
}
void snippets::op::Subgraph::control_flow_transformations(lowered::LinearIR& linear_ir,
lowered::pass::PassPipeline& target_markup_pipeline,
lowered::pass::PassPipeline& target_pipeline) {
INTERNAL_OP_SCOPE(Subgraph);
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::control_flow_transformations")
@ -514,10 +516,15 @@ void snippets::op::Subgraph::control_flow_transformations(lowered::LinearIR& lin
const size_t vector_size = get_generator()->get_target_machine()->get_lanes();
const int32_t buffer_allocation_rank = static_cast<int32_t>(linear_ir.get_config().m_loop_depth);
// Ticket: 113666
// TODO: Make pass pipeline with backend passes more flexible
target_markup_pipeline.run(linear_ir);
lowered::pass::PassPipeline common_pipeline;
common_pipeline.register_pass<lowered::pass::MarkLoops>(vector_size);
common_pipeline.register_pass<lowered::pass::SoftmaxDecomposition>(vector_size);
common_pipeline.register_pass<lowered::pass::FuseLoops>();
common_pipeline.register_pass<lowered::pass::SplitLoops>();
common_pipeline.register_pass<lowered::pass::MoveResultOutOfLoop>();
common_pipeline.register_pass<lowered::pass::InsertBuffers>(buffer_allocation_rank);
common_pipeline.register_pass<lowered::pass::InsertLoadStore>(vector_size);
@ -557,22 +564,24 @@ snippets::Schedule snippets::op::Subgraph::generate(const BlockedShapeVector& ou
ov::pass::Manager& pre_common,
ov::pass::Manager& post_common,
ov::pass::Manager& post_precision,
lowered::pass::PassPipeline& target_lowered_markup_pipeline,
lowered::pass::PassPipeline& target_lowered_pipeline,
const void* compile_params) {
canonicalize(output_shapes, input_shapes);
return generate(pre_common, post_common, post_precision, target_lowered_pipeline, compile_params);
return generate(pre_common, post_common, post_precision, target_lowered_markup_pipeline, target_lowered_pipeline, compile_params);
}
snippets::Schedule snippets::op::Subgraph::generate(const void* compile_params) {
auto mngr = ov::pass::Manager();
auto lowered = lowered::pass::PassPipeline();
return generate(mngr, mngr, mngr, lowered, compile_params);
return generate(mngr, mngr, mngr, lowered, lowered, compile_params);
}
snippets::Schedule snippets::op::Subgraph::generate(
ov::pass::Manager& pre_common,
ov::pass::Manager& post_common,
ov::pass::Manager& post_precision,
lowered::pass::PassPipeline& target_lowered_markup_pipeline,
lowered::pass::PassPipeline& target_lowered_pipeline,
const void* compile_params) {
INTERNAL_OP_SCOPE(Subgraph);
@ -587,7 +596,7 @@ snippets::Schedule snippets::op::Subgraph::generate(
lowering_config.m_loop_depth = tileRank;
lowered::LinearIR linear_ir = lowered::LinearIR(body_ptr(), lowering_config);
control_flow_transformations(linear_ir, target_lowered_pipeline);
control_flow_transformations(linear_ir, target_lowered_markup_pipeline, target_lowered_pipeline);
// actual code emission
const auto& lowering_result = m_generator->generate(linear_ir, lowering_config, compile_params);

View File

@ -126,7 +126,8 @@ std::shared_ptr<ov::snippets::op::Subgraph> LoweringTests::getLoweredSubgraph(co
}
body_rt_info["PluginShapesOverride"] = new_shapes;
subgraph->set_tile_rank(2);
subgraph->generate(pre_dialect, post_precision, post_precision, lowered_pipeline);
ov::snippets::lowered::pass::PassPipeline empty_pipeline;
subgraph->generate(pre_dialect, post_precision, post_precision, empty_pipeline, lowered_pipeline);
return subgraph;
}

View File

@ -701,8 +701,8 @@ void StoreConvertEmitter::emit_isa(const std::vector<size_t> &in, const std::vec
void StoreConvertEmitter::emit_data() const {
store_emitter->emit_data();
}
size_t BrgemmEmitter::getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) const {
return mIdx * 4 + kIdx * 2 + nIdx;
size_t BrgemmEmitter::getBrgIdx(size_t kIdx, size_t nIdx) const {
return kIdx * 2 + nIdx;
}
BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa,
const std::shared_ptr<ov::Node>& node) : jit_emitter(h, isa, node) {
@ -758,10 +758,8 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl:
return std::distance(layout.begin(), std::find(layout.begin(), layout.end(), idx));
};
m_M = C_shape[get_ordered_idx(C_layout, C_layout.size() - 2)];
m_K = A_shape[get_ordered_idx(A_layout, A_layout.size() - 1)];
m_M_blk = matmulOptimalM;
m_M_tail = m_M % m_M_blk;
m_M = brgemm_node->get_input_count(0);
m_N = C_shape[get_ordered_idx(C_layout, C_layout.size() - 1)];
auto brg0Prc = InferenceEngine::details::convertPrecision(brgemm_node->get_input_element_type(0));
@ -780,34 +778,28 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl:
: m_K;
m_K_tail = m_K % m_K_blk;
size_t brg0BaseIdx = std::numeric_limits<size_t>::max();
for (size_t m = 0; m < 2; m++) {
for (size_t k = 0; k < 2; k++) {
for (size_t n = 0; n < 2; n++) {
auto& brgemmCtx = m_brgCtxs0[getBrgIdx(m, k, n)];
for (size_t k = 0; k < 2; k++) {
for (size_t n = 0; n < 2; n++) {
auto& brgemmCtx = m_brgCtxs0[getBrgIdx(k, n)];
auto M_ = m ? m_M_tail
: m_M < m_M_blk ? 0 : m_M_blk;
auto N_ = n ? m_N_tail : m_N - m_N_tail;
auto K_ = k ? m_K_tail : m_K - m_K_tail;
auto beta = k && m_brgCtxs0[getBrgIdx(m, 0, n)].K != 0 ? 1.0f : 0.0f;
auto M_ = m_M;
auto N_ = n ? m_N_tail : m_N - m_N_tail;
auto K_ = k ? m_K_tail : m_K - m_K_tail;
auto beta = k && m_brgCtxs0[getBrgIdx(0, n)].K != 0 ? 1.0f : 0.0f;
brgemmCtx.M = M_;
brgemmCtx.N = N_;
brgemmCtx.K = K_;
brgemmCtx.LDA = leading_dimensions[0];
brgemmCtx.LDB = brgemm_node->is_with_data_repacking() ? rnd_up(m_N, m_N_blk) : leading_dimensions[1];
brgemmCtx.LDC = leading_dimensions[2];
brgemmCtx.dt_in0 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::IEPrecisionToDataType(brg0Prc));
brgemmCtx.dt_in1 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::IEPrecisionToDataType(brg1Prc));
brgemmCtx.beta = beta;
brgemmCtx.M = M_;
brgemmCtx.N = N_;
brgemmCtx.K = K_;
brgemmCtx.LDA = leading_dimensions[0];
brgemmCtx.LDB = brgemm_node->is_with_data_repacking() ? rnd_up(m_N, m_N_blk) : leading_dimensions[1];
brgemmCtx.LDC = leading_dimensions[2];
brgemmCtx.dt_in0 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::IEPrecisionToDataType(brg0Prc));
brgemmCtx.dt_in1 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::IEPrecisionToDataType(brg1Prc));
brgemmCtx.beta = beta;
// don't create brgemm kernels for empty tiles
if (M_ != 0 && K_ != 0 && N_ != 0) {
if (brg0BaseIdx == std::numeric_limits<size_t>::max())
brg0BaseIdx = getBrgIdx(m, k, n);
initBrgemm(brgemmCtx, m_brgKernels0[getBrgIdx(m, k, n)], brgWithAMX);
}
// don't create brgemm kernels for empty tiles
if (M_ != 0 && K_ != 0 && N_ != 0) {
initBrgemm(brgemmCtx, m_brgKernels0[getBrgIdx(k, n)], brgWithAMX);
}
}
}
@ -878,36 +870,31 @@ void BrgemmEmitter::emit_impl(const std::vector<size_t>& in,
}
Xbyak::Reg64 output_0(static_cast<int>(out[0]));
for (size_t mb = 0; mb < div_up(m_M, m_M_blk); mb++) {
const bool is_M_tail = (m_M - mb * m_M_blk < m_M_blk);
size_t brgIdx0 = getBrgIdx(0, 0);
size_t K0_step0 = m_brgCtxs0[brgIdx0].K;
size_t K0_step1 = m_brgCtxs0[brgIdx0].K * m_brgCtxs0[brgIdx0].LDB;
size_t N0_step0 = m_brgCtxs0[brgIdx0].N * m_brg0VnniFactor;
size_t N0_step1 = m_brgCtxs0[brgIdx0].N;
for (size_t n = 0; n < 2; n++) {
for (size_t k = 0; k < 2; k++) {
auto& brgemmCtx = m_brgCtxs0[getBrgIdx(k, n)];
size_t brgIdx0 = getBrgIdx(0, 0, 0);
size_t K0_step0 = m_brgCtxs0[brgIdx0].K;
size_t K0_step1 = m_brgCtxs0[brgIdx0].K * m_brgCtxs0[brgIdx0].LDB;
size_t N0_step0 = m_brgCtxs0[brgIdx0].N * m_brg0VnniFactor;
size_t N0_step1 = m_brgCtxs0[brgIdx0].N;
for (size_t n = 0; n < 2; n++) {
for (size_t k = 0; k < 2; k++) {
size_t mIdx = is_M_tail ? 1 : 0;
auto& brgemmCtx = m_brgCtxs0[getBrgIdx(mIdx, k, n)];
if (brgemmCtx.K != 0 && brgemmCtx.N != 0) {
const size_t in0_offset = m_load_offset_a + k * K0_step0 * io_data_size[0];
const size_t in1_offset = m_load_offset_b + (k * K0_step1 + n * N0_step0) * io_data_size[1];
const size_t in2_offset = m_load_offset_scratch + (m_with_comp ? n * N0_step1 * sizeof(int32_t) : 0);
const size_t out0_offset = m_store_offset_c + n * N0_step1 * io_data_size[2];
if (brgemmCtx.K != 0 && brgemmCtx.N != 0) {
const size_t in0_offset = m_load_offset_a + (k * K0_step0 + mb * m_M_blk * brgemmCtx.LDA) * io_data_size[0];
const size_t in1_offset = m_load_offset_b + (k * K0_step1 + n * N0_step0) * io_data_size[1];
const size_t in2_offset = m_load_offset_scratch + (m_with_comp ? n * N0_step1 * sizeof(int32_t) : 0);
const size_t out0_offset = m_store_offset_c + (n * N0_step1 + mb * m_M_blk * brgemmCtx.LDC) * io_data_size[2];
emit_brgemm_kernel_call(m_brgKernels0[getBrgIdx(mIdx, k, n)].get(),
brgemmCtx,
input_0,
input_1,
input_2,
output_0,
in0_offset,
in1_offset,
in2_offset,
out0_offset);
}
emit_brgemm_kernel_call(m_brgKernels0[getBrgIdx(k, n)].get(),
brgemmCtx,
input_0,
input_1,
input_2,
output_0,
in0_offset,
in1_offset,
in2_offset,
out0_offset);
}
}
}

View File

@ -353,7 +353,7 @@ private:
float beta;
};
void initBrgemm(brgemmCtx& ctx, std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& brgKernel, bool use_amx) const;
size_t getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) const;
size_t getBrgIdx(size_t kIdx, size_t nIdx) const;
void emit_brgemm_kernel_call(const dnnl::impl::cpu::x64::brgemm_kernel_t* brg_kernel, const brgemmCtx& ctx,
Xbyak::Reg64 addr_A, Xbyak::Reg64 addr_B, Xbyak::Reg64 scratch, Xbyak::Reg64 addr_C,
@ -362,11 +362,10 @@ private:
static void kernel_execute(const dnnl::impl::cpu::x64::brgemm_kernel_t *brg_kernel, const void *A, const void *B, void *C, void *scratch, int with_comp);
static constexpr size_t BRGEMM_KERNELS_NUM = 8;
static constexpr size_t matmulOptimalM = 32;
brgemmCtx m_brgCtxs0[BRGEMM_KERNELS_NUM];
std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t> m_brgKernels0[BRGEMM_KERNELS_NUM];
size_t m_M, m_M_blk, m_M_tail;
size_t m_M;
size_t m_K, m_K_blk, m_K_tail;
size_t m_N, m_N_blk, m_N_tail;
size_t m_brg0VnniFactor;

View File

@ -25,6 +25,7 @@
#include "utils/cpu_utils.hpp"
#include "emitters/x64/cpu_generator.hpp"
#include "transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.hpp"
#include "transformations/snippets/x64/pass/lowered/brgemm_blocking.hpp"
#include "transformations/snippets/x64/pass/mul_add_to_fma.hpp"
#include "transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp"
#include "transformations/snippets/x64/pass/remove_converts.hpp"
@ -564,6 +565,9 @@ void Snippet::generate(const jit_snippets_compile_args* jcp) {
CPU_REGISTER_PASS_X64(post_precision, ov::intel_cpu::pass::RemoveConverts);
CPU_REGISTER_PASS_X64(post_precision, ov::intel_cpu::pass::MulAddToFMA);
ov::snippets::lowered::pass::PassPipeline control_flow_markup_pipeline;
CPU_REGISTER_PASS_X64(control_flow_markup_pipeline, ov::intel_cpu::pass::BrgemmBlocking);
ov::snippets::lowered::pass::PassPipeline control_flow_pipeline;
CPU_REGISTER_PASS_X64(control_flow_pipeline, ov::intel_cpu::pass::FuseLoadStoreConvert);
@ -571,6 +575,7 @@ void Snippet::generate(const jit_snippets_compile_args* jcp) {
pre_dialect,
post_dialect,
post_precision,
control_flow_markup_pipeline,
control_flow_pipeline,
reinterpret_cast<const void*>(jcp));
}

View File

@ -0,0 +1,80 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "brgemm_blocking.hpp"
#include "openvino/pass/pattern/matcher.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "snippets/itt.hpp"
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/loop_manager.hpp"
#include "snippets/snippets_isa.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
namespace ov {
namespace intel_cpu {
namespace pass {
using LoopManager = snippets::lowered::LinearIR::LoopManager;
using LoopInfoPtr = LoopManager::LoopInfoPtr;
using LoopPort = LoopManager::LoopPort;
BrgemmBlocking::BrgemmBlocking() : Pass() {}
// Wraps every BrgemmCPU expression that is not yet inside an M-dimension loop
// into a blocking loop over M with a fixed block size.
// Returns true iff at least one blocking loop was created.
bool BrgemmBlocking::run(snippets::lowered::LinearIR& linear_ir) {
    OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::BrgemmBlocking")
    if (linear_ir.empty())
        return false;

    // Ticket: 113745
    // TODO: make the block size configurable
    const auto block_size = 32;
    const auto dim_idx = 1;

    const auto& loop_manager = linear_ir.get_loop_manager();

    // True if the expression is already covered by a loop over the M dimension.
    // In that case the loop increment must match the Brgemm input count.
    auto blocking_loop_exists = [&](const ov::snippets::lowered::ExpressionPtr& expr,
                                    const std::shared_ptr<ov::intel_cpu::BrgemmCPU>& brgemm) {
        const auto& loop_ids = expr->get_loop_ids();
        for (const auto& id : loop_ids) {
            const auto loop = loop_manager->get_loop_info(id);
            if (loop->dim_idx == dim_idx) {
                OPENVINO_ASSERT(brgemm->get_input_count(0) == loop->increment,
                                "Brgemm ", brgemm, " has input count (", brgemm->get_input_count(0),
                                ") which doesn't match the increment(", loop->increment, ") of loop by M");
                return true;
            }
        }
        return false;
    };

    bool modified = false;
    for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) {
        const auto& expr = *expr_it;
        const auto brgemm = ov::as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
        if (!brgemm || blocking_loop_exists(expr, brgemm))
            continue;

        // The M dimension is taken from input 0 with respect to its layout.
        const auto& input_shape_0 = expr->get_input_port_descriptor(0)->get_shape();
        const auto& input_layout_0 = expr->get_input_port_descriptor(0)->get_layout();
        const auto& dim = *(input_layout_0.rbegin() + dim_idx);
        const auto& m = input_shape_0[dim];

        brgemm->set_input_count(block_size);

        const auto work_amount = m;
        const auto increment = block_size;

        // Only input 0 and the output advance with the blocking loop;
        // input 1 (and the scratchpad, if present) is reused across M blocks.
        std::vector<LoopPort> entries{LoopPort(expr->get_input_port(0), true), LoopPort(expr->get_input_port(1), false)};
        if (brgemm->is_with_scratchpad())
            entries.emplace_back(expr->get_input_port(2), false);
        std::vector<LoopPort> exits{LoopPort(expr->get_output_port(0), true)};
        loop_manager->mark_loop(expr_it, std::next(expr_it), work_amount, increment, dim_idx, entries, exits);

        // Bugfix: record that the IR was changed — previously `modified` was never set,
        // so the pass always reported false even after marking blocking loops.
        modified = true;
    }
    return modified;
}
} // namespace pass
} // namespace intel_cpu
} // namespace ov

View File

@ -0,0 +1,28 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "snippets/lowered/pass/pass.hpp"
namespace ov {
namespace intel_cpu {
namespace pass {
/**
 * @interface BrgemmBlocking
 * @brief Covers BrgemmCPU with a blocking loop by the M dimension
 * @ingroup snippets
 */
class BrgemmBlocking : public snippets::lowered::pass::Pass {
public:
    OPENVINO_RTTI("BrgemmBlocking", "Pass")
    BrgemmBlocking();
    // Marks a blocking loop over M around each BrgemmCPU expression in the IR;
    // returns true if the IR was modified.
    bool run(snippets::lowered::LinearIR& linear_ir) override;
};
} // namespace pass
} // namespace intel_cpu
} // namespace ov

View File

@ -21,7 +21,7 @@ namespace pass {
class FuseLoadStoreConvert: public snippets::lowered::pass::Pass {
public:
FuseLoadStoreConvert() = default;
OPENVINO_RTTI("FuseLoadStoreConvert", "LinearIRTransformation");
OPENVINO_RTTI("FuseLoadStoreConvert", "Pass");
bool run(snippets::lowered::LinearIR& linear_ir) override;
private: