[Snippets] MatMul: blocking by M dimension at the LIR stage (#18169)
* [Snippets] MatMul: blocking by M dim at LIR level * Alexandra's comments applied * Ivan's comments applied * Fix warning
This commit is contained in:
parent
3bc8065ca3
commit
18e737493c
@ -76,11 +76,19 @@ public:
|
||||
LinearIR::constExprIt loop_end_pos,
|
||||
size_t loop_depth, size_t vector_size);
|
||||
// Return Loop ID
|
||||
template <typename T>
|
||||
size_t mark_loop(LinearIR::constExprIt loop_begin_pos,
|
||||
LinearIR::constExprIt loop_end_pos,
|
||||
size_t work_amount, size_t work_amount_increment, size_t dim_idx,
|
||||
const std::vector<ExpressionPort>& entries,
|
||||
const std::vector<ExpressionPort>& exits);
|
||||
const std::vector<T>& entries,
|
||||
const std::vector<T>& exits) {
|
||||
const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, work_amount_increment, dim_idx, entries, exits);
|
||||
const auto loop_id = this->add_loop_info(loop_info);
|
||||
for (auto expr_it = loop_begin_pos; expr_it != loop_end_pos; ++expr_it) {
|
||||
insert_loop_id(*expr_it, loop_id);
|
||||
}
|
||||
return loop_id;
|
||||
}
|
||||
|
||||
void fuse_loops(const LinearIR& linear_ir, size_t loop_id_upper, size_t loop_id_lower, bool fuse_into_upper = true);
|
||||
void fuse_loops(LinearIR::constExprIt loop_begin_target, LinearIR::constExprIt loop_end_target,
|
||||
@ -123,6 +131,8 @@ public:
|
||||
LinearIR::constExprIt& loop_end_pos,
|
||||
size_t loop_id, bool loop_ops_inserted = false);
|
||||
|
||||
LoopPort get_loop_port_by_expr_port(const ExpressionPort& expr_port, const size_t loop_id);
|
||||
|
||||
private:
|
||||
static void get_io_loop_ports(LinearIR::constExprIt loop_begin_pos,
|
||||
LinearIR::constExprIt loop_end_pos,
|
||||
|
@ -42,6 +42,12 @@ public:
|
||||
FuseLoops();
|
||||
bool run(LinearIR& linear_ir) override;
|
||||
|
||||
// This method checks that all ports which connect lower and upper loops are incremented.
|
||||
// This helps to avoid fusing for the ports with incompleted data
|
||||
static bool loop_ports_are_compatible(const LinearIR::LoopManagerPtr& loop_manager,
|
||||
const size_t loop_lower_id,
|
||||
const size_t loop_upper_id);
|
||||
|
||||
private:
|
||||
static bool can_be_fused(const LinearIR::LoopManager::LoopInfoPtr& loop_current,
|
||||
const LinearIR::LoopManager::LoopInfoPtr& loop_target);
|
||||
|
@ -0,0 +1,46 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "pass.hpp"
|
||||
#include "snippets/lowered/loop_manager.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace snippets {
|
||||
namespace lowered {
|
||||
namespace pass {
|
||||
|
||||
/**
|
||||
* @interface SplitLoops
|
||||
* @brief If loop_1 has larger increment but the same works amount of loop_2, that follows loop_1, then split loop_2
|
||||
* into two loops so the outermost of the split loops could be fused with the loop_1 using the pass `FuseLoops`.
|
||||
* Example:
|
||||
* Loop_1_begin Loop_1_begin Loop_1_begin
|
||||
* ... ... ...
|
||||
* Loop_1_end (wa = 128, inc = 32) Loop_1_end (wa = 128, inc = 32) Split_loop_2_begin
|
||||
* ... Splitting ... Fusing ...
|
||||
* Loop_2_begin => Split_loop_1_begin => Split_loop_2_end (wa = 32, inc = 1)
|
||||
* ... Split_loop_2_begin ...
|
||||
* Loop_2_end (wa = 128, inc = 1) ... Loop_1_end (wa = 128, inc = 32)
|
||||
* Split_loop_2_end (wa = 32, inc = 1)
|
||||
* Split_loop_1_end (wa = 128, inc = 32)
|
||||
* @ingroup snippets
|
||||
*/
|
||||
|
||||
class SplitLoops : public Pass {
|
||||
public:
|
||||
OPENVINO_RTTI("SplitLoops", "Pass")
|
||||
SplitLoops();
|
||||
bool run(LinearIR& linear_ir) override;
|
||||
|
||||
private:
|
||||
static bool can_be_split(const LinearIR::LoopManager::LoopInfoPtr& current,
|
||||
const LinearIR::LoopManager::LoopInfoPtr& target);
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace lowered
|
||||
} // namespace snippets
|
||||
} // namespace ov
|
@ -104,12 +104,14 @@ public:
|
||||
ov::pass::Manager& pre_common,
|
||||
ov::pass::Manager& post_common,
|
||||
ov::pass::Manager& post_precision,
|
||||
lowered::pass::PassPipeline& target_lowered_markup_pipeline,
|
||||
lowered::pass::PassPipeline& target_lowered_pipeline,
|
||||
const void* compile_params = nullptr);
|
||||
snippets::Schedule generate(const BlockedShapeVector& output_shapes, const BlockedShapeVector& input_shapes, const void* compile_params = nullptr);
|
||||
snippets::Schedule generate(ov::pass::Manager& pre_common,
|
||||
ov::pass::Manager& post_common,
|
||||
ov::pass::Manager& post_precision,
|
||||
lowered::pass::PassPipeline& target_lowered_markup_pipeline,
|
||||
lowered::pass::PassPipeline& target_lowered_pipeline,
|
||||
const void* compile_params = nullptr);
|
||||
snippets::Schedule generate(const void* compile_params = nullptr);
|
||||
@ -144,7 +146,9 @@ public:
|
||||
private:
|
||||
void align_element_types(const BlockedShapeVector& outputShapes, const BlockedShapeVector& inputShapes);
|
||||
void data_flow_transformations(ov::pass::Manager& pre_common, ov::pass::Manager& post_common, ov::pass::Manager& post_precision);
|
||||
void control_flow_transformations(lowered::LinearIR& linear_ir, lowered::pass::PassPipeline& target_pipeline);
|
||||
void control_flow_transformations(lowered::LinearIR& linear_ir,
|
||||
lowered::pass::PassPipeline& target_markup_pipeline,
|
||||
lowered::pass::PassPipeline& target_pipeline);
|
||||
void init_config();
|
||||
// Count of Subgraph virtual ports:
|
||||
// - Potential non-scalar Constants that will be created after some transformations (At the moment it's relevant only for FakeQuantize decomposition)
|
||||
|
@ -113,6 +113,18 @@ void LinearIR::LoopManager::get_loop_bounds(const LinearIR &linear_ir,
|
||||
}
|
||||
}
|
||||
|
||||
LinearIR::LoopManager::LoopPort LinearIR::LoopManager::get_loop_port_by_expr_port(const ExpressionPort& expr_port, const size_t loop_id) {
|
||||
auto get_loop_port = [&](const std::vector<LinearIR::LoopManager::LoopPort>& ports) {
|
||||
auto it = std::find_if(ports.cbegin(), ports.cend(), [&](const LinearIR::LoopManager::LoopPort& p) { return *p.expr_port == expr_port; });
|
||||
if (it == ports.cend())
|
||||
OPENVINO_THROW("Expression has not been found among loop ports. Loop id: " + std::to_string(loop_id));
|
||||
return *it;
|
||||
};
|
||||
const auto& loop_info = get_loop_info(loop_id);
|
||||
return expr_port.get_type() == ExpressionPort::Input ? get_loop_port(loop_info->entry_points)
|
||||
: get_loop_port(loop_info->exit_points);
|
||||
}
|
||||
|
||||
void LinearIR::LoopManager::get_io_loop_ports(LinearIR::constExprIt loop_begin_pos,
|
||||
LinearIR::constExprIt loop_end_pos,
|
||||
std::vector<ExpressionPort> &entries,
|
||||
@ -211,18 +223,6 @@ void LinearIR::LoopManager::mark_loop(LinearIR::constExprIt loop_begin_pos,
|
||||
}
|
||||
}
|
||||
|
||||
size_t LinearIR::LoopManager::mark_loop(LinearIR::constExprIt loop_begin_pos,
|
||||
LinearIR::constExprIt loop_end_pos,
|
||||
size_t work_amount, size_t work_amount_increment, size_t dim_idx,
|
||||
const std::vector<ExpressionPort>& entries,
|
||||
const std::vector<ExpressionPort>& exits) {
|
||||
const auto loop_info = std::make_shared<LoopManager::LoopInfo>(work_amount, work_amount_increment, dim_idx, entries, exits);
|
||||
const auto loop_id = this->add_loop_info(loop_info);
|
||||
for (auto expr_it = loop_begin_pos; expr_it != loop_end_pos; ++expr_it) {
|
||||
insert_loop_id(*expr_it, loop_id);
|
||||
}
|
||||
return loop_id;
|
||||
}
|
||||
void LinearIR::LoopManager::fuse_loops(const LinearIR& linear_ir, size_t loop_id_upper, size_t loop_id_lower, bool fuse_into_upper) {
|
||||
LinearIR::constExprIt loop_begin_target, loop_end_target;
|
||||
get_loop_bounds(linear_ir, fuse_into_upper ? loop_id_lower : loop_id_upper, loop_begin_target, loop_end_target);
|
||||
|
@ -24,6 +24,23 @@ using LoopInfoPtr = LoopManager::LoopInfoPtr;
|
||||
|
||||
FuseLoops::FuseLoops() : Pass() {}
|
||||
|
||||
bool FuseLoops::loop_ports_are_compatible(const LinearIR::LoopManagerPtr& loop_manager,
|
||||
const size_t loop_lower_id,
|
||||
const size_t loop_upper_id) {
|
||||
const auto loop_lower = loop_manager->get_loop_info(loop_lower_id);
|
||||
for (const auto& entry : loop_lower->entry_points) {
|
||||
const auto& src_port = entry.expr_port->get_port_connector_ptr()->get_source();
|
||||
if (is_loop_id_found(src_port.get_expr()->get_loop_ids(), loop_upper_id)) {
|
||||
if (!entry.is_incremented)
|
||||
return false;
|
||||
auto src_loop_port = loop_manager->get_loop_port_by_expr_port(src_port, loop_upper_id);
|
||||
if (!src_loop_port.is_incremented)
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool FuseLoops::can_be_fused(const LoopInfoPtr& loop_current, const LoopInfoPtr& loop_target) {
|
||||
auto current_work_amount = loop_current->work_amount;
|
||||
auto target_work_amount = loop_target->work_amount;
|
||||
@ -79,7 +96,7 @@ bool FuseLoops::fuse_upper_into_current(LinearIR& linear_ir, const LinearIR::Loo
|
||||
LinearIR::constExprIt& current_loop_begin_pos, LinearIR::constExprIt& current_loop_end_pos) {
|
||||
const auto& loop_current = loop_manager->get_loop_info(current_loop_id);
|
||||
const auto& loop_target = loop_manager->get_loop_info(target_loop_id);
|
||||
if (!can_be_fused(loop_current, loop_target))
|
||||
if (!can_be_fused(loop_current, loop_target) || !loop_ports_are_compatible(loop_manager, current_loop_id, target_loop_id))
|
||||
return false;
|
||||
|
||||
// We can fuse Loop_up to Loop_down only in cases when other consumers of Loop_up are after Loop_down
|
||||
@ -129,7 +146,7 @@ bool FuseLoops::fuse_lower_into_current(LinearIR& linear_ir, const LinearIR::Loo
|
||||
LinearIR::constExprIt& current_loop_begin_pos, LinearIR::constExprIt& current_loop_end_pos) {
|
||||
const auto& loop_current = loop_manager->get_loop_info(current_loop_id);
|
||||
const auto& loop_target = loop_manager->get_loop_info(target_loop_id);
|
||||
if (!can_be_fused(loop_current, loop_target))
|
||||
if (!can_be_fused(loop_current, loop_target) || !loop_ports_are_compatible(loop_manager, target_loop_id, current_loop_id))
|
||||
return false;
|
||||
|
||||
// We can fuse Loop_down to Loop_up only in cases when other parents of Loop_down are before Loop_up
|
||||
|
@ -51,10 +51,15 @@ void InitLoops::init_ptr_increments(std::vector<LoopPort>& loop_inputs, std::vec
|
||||
const auto& layout = port->get_descriptor_ptr()->get_layout();
|
||||
const auto& shape = port->get_descriptor_ptr()->get_shape();
|
||||
const auto& dim = *(layout.rbegin() + dim_idx);
|
||||
// Ticket: 113106
|
||||
// WA: the current logic doesn't support the case with transposed output shape for brgemm layer
|
||||
// but for all existing cases planar layout can be used
|
||||
std::vector<size_t> planar(layout.size());
|
||||
std::iota(planar.begin(), planar.end(), 0);
|
||||
loop_output.ptr_increment = 0;
|
||||
// If relevant dim is not broadcasted, then ptr_increment is the dim stride in the new layout
|
||||
if (loop_output.is_incremented && !(shape[dim] == 1 && work_amount != 1)) {
|
||||
loop_output.ptr_increment = get_dim_stride(dim, layout, shape);
|
||||
loop_output.ptr_increment = get_dim_stride(dim, planar, shape);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -4,9 +4,10 @@
|
||||
|
||||
#include "snippets/lowered/pass/insert_buffers.hpp"
|
||||
|
||||
#include "snippets/itt.hpp"
|
||||
#include "snippets/lowered/linear_ir.hpp"
|
||||
#include "snippets/snippets_isa.hpp"
|
||||
#include "snippets/itt.hpp"
|
||||
#include "snippets/utils.hpp"
|
||||
|
||||
|
||||
namespace ov {
|
||||
@ -28,6 +29,49 @@ std::vector<size_t> get_buffer_loop_ids(const std::vector<size_t>& lhs, const st
|
||||
}
|
||||
return buffer_loop_ids;
|
||||
}
|
||||
|
||||
// Ticket: 113744
|
||||
// TODO: This logic covers only several specific cases so it should be generalized.
|
||||
ov::Shape compute_allocation_shape(const LinearIR::LoopManagerPtr& loop_manager,
|
||||
const std::vector<size_t>& buffer_loop_ids,
|
||||
const std::vector<size_t>& parent_loop_ids,
|
||||
const ov::Output<ov::Node>& parent_output,
|
||||
const int allocation_rank) {
|
||||
const size_t rank = allocation_rank >= 0 ? allocation_rank : parent_output.get_shape().size();
|
||||
ov::Shape allocation_shape(rank);
|
||||
const auto port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(parent_output);
|
||||
const auto planar_shape = utils::get_reordered_planar_shape(ov::Shape{port->get_shape()}, port->get_layout());
|
||||
for (size_t i = 0; i < rank; ++i) {
|
||||
*(allocation_shape.rbegin() + i) = (planar_shape.rbegin() + i)->get_length();
|
||||
}
|
||||
|
||||
if (buffer_loop_ids.empty() || parent_loop_ids.empty()) {
|
||||
return allocation_shape;
|
||||
}
|
||||
|
||||
auto set_rest_dims_to_ones = [&](const int filled_dims_count) {
|
||||
for (int i = 0; i < static_cast<int>(allocation_shape.size()) - filled_dims_count; ++i) {
|
||||
allocation_shape[i] = 1;
|
||||
}
|
||||
};
|
||||
|
||||
// In some cases it's possible to allocate less shape
|
||||
// 1. Buffer and its parent are in the same loop: allocation size for the outer dimension can be extracted from loop increment
|
||||
// 2. Buffer is outside the parent's loops: allocation size can be extracted from the corresponding loop work amount
|
||||
// TODO: Use general logic with the help of memory counts for allocation shape computation
|
||||
if (buffer_loop_ids.back() == parent_loop_ids.back()) {
|
||||
const auto buffer_loop = loop_manager->get_loop_info(buffer_loop_ids.back());
|
||||
*(allocation_shape.rbegin() + 1) = buffer_loop->increment;
|
||||
set_rest_dims_to_ones(2);
|
||||
} else {
|
||||
for (size_t i = 0; i < std::min(rank, parent_loop_ids.size()); ++i) {
|
||||
const auto loop = loop_manager->get_loop_info(*(parent_loop_ids.rbegin() + i));
|
||||
*(allocation_shape.rbegin() + i) = loop->work_amount;
|
||||
}
|
||||
set_rest_dims_to_ones(static_cast<int>(parent_loop_ids.size()));
|
||||
}
|
||||
return allocation_shape;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
InsertBuffers::InsertBuffers(int32_t buffer_allocation_rank)
|
||||
@ -110,7 +154,12 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPt
|
||||
// Current expr Loop identifies: 3, 4, 6
|
||||
// Need to insert between 2nd and 4th Loops - after 2nd Loop
|
||||
const auto pos = insertion_position(linear_ir, loop_manager, parent_expr, expr);
|
||||
const auto buffer = std::make_shared<op::Buffer>(parent->output(parent_port), m_buffer_allocation_rank);
|
||||
const auto allocation_shape = compute_allocation_shape(loop_manager,
|
||||
buffer_loop_ids,
|
||||
parent_loops,
|
||||
parent->output(parent_port),
|
||||
m_buffer_allocation_rank);
|
||||
const auto buffer = std::make_shared<op::Buffer>(parent->output(parent_port), allocation_shape);
|
||||
PortDescriptorUtils::set_port_descriptor_ptr(buffer->output(0), parent_expr_output.get_descriptor_ptr()->clone());
|
||||
// Output connector is automatically filled from PortDescriptor
|
||||
const auto buffer_expr = linear_ir.create_expression(buffer, {input_connector});
|
||||
@ -183,7 +232,12 @@ void InsertBuffers::insertion(LinearIR& linear_ir, const LinearIR::LoopManagerPt
|
||||
// Note: All potential consumers must have the same count of first equal Loop identifies and the same count of different last identifies
|
||||
const auto pos = insertion_position(linear_ir, loop_manager, expr, (*potential_consumers.begin()).get_expr());
|
||||
|
||||
auto buffer = std::make_shared<op::Buffer>(node->output(port), m_buffer_allocation_rank);
|
||||
const auto allocation_shape = compute_allocation_shape(loop_manager,
|
||||
buffer_loop_ids,
|
||||
current_loops,
|
||||
node->output(port),
|
||||
m_buffer_allocation_rank);
|
||||
auto buffer = std::make_shared<op::Buffer>(node->output(port), allocation_shape);
|
||||
PortDescriptorUtils::set_port_descriptor_ptr(buffer->output(0), exit_port->get_descriptor_ptr()->clone());
|
||||
// We cannot insert Node output connector on Buffer output because not all consumers of Node needs Buffer
|
||||
// Example:
|
||||
|
96
src/common/snippets/src/lowered/pass/split_loops.cpp
Normal file
96
src/common/snippets/src/lowered/pass/split_loops.cpp
Normal file
@ -0,0 +1,96 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "snippets/lowered/pass/split_loops.hpp"
|
||||
|
||||
#include "snippets/lowered/pass/fuse_loops.hpp"
|
||||
#include "snippets/lowered/linear_ir.hpp"
|
||||
#include "snippets/lowered/loop_manager.hpp"
|
||||
#include "snippets/snippets_isa.hpp"
|
||||
#include "snippets/itt.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace snippets {
|
||||
namespace lowered {
|
||||
namespace pass {
|
||||
using LoopManager = LinearIR::LoopManager;
|
||||
using LoopInfoPtr = LoopManager::LoopInfoPtr;
|
||||
|
||||
SplitLoops::SplitLoops() : Pass() {}
|
||||
|
||||
bool SplitLoops::can_be_split(const LoopInfoPtr& current, const LoopInfoPtr& parent) {
|
||||
return current->work_amount == parent->work_amount && current->dim_idx == parent->dim_idx &&
|
||||
current->increment != parent->increment;
|
||||
}
|
||||
|
||||
bool SplitLoops::run(LinearIR& linear_ir) {
|
||||
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::SplitLoops")
|
||||
if (linear_ir.empty())
|
||||
return false;
|
||||
|
||||
const auto& loop_manager = linear_ir.get_loop_manager();
|
||||
bool loop_was_split = false;
|
||||
for (const auto& expr : linear_ir) {
|
||||
const auto& loop_ids = expr->get_loop_ids();
|
||||
if (loop_ids.empty())
|
||||
continue;
|
||||
|
||||
// Ticket: 113755
|
||||
// Note: we currently consider only the outermost loops for splitting
|
||||
// Splitting could also be done in a more general case, but the splitted loop and its parent must always
|
||||
// be in the same set of outer loops. Otherwise they won't be fused.
|
||||
const auto& loop_id = loop_ids.front();
|
||||
const auto loop = loop_manager->get_loop_info(loop_id);
|
||||
for (const auto& entry_point : loop->entry_points) {
|
||||
const auto& parent_port = entry_point.expr_port->get_port_connector_ptr()->get_source();
|
||||
const auto& parent_expr = parent_port.get_expr();
|
||||
const auto parent_loop_ids = parent_expr->get_loop_ids();
|
||||
if (parent_loop_ids.empty())
|
||||
continue;
|
||||
|
||||
const auto& parent_loop_id = parent_loop_ids.front();
|
||||
const auto parent_loop_port = loop_manager->get_loop_port_by_expr_port(parent_port, parent_loop_id);
|
||||
// We don't split loop which are not compatible with parent loop because such loops will not be fused
|
||||
if (!FuseLoops::loop_ports_are_compatible(loop_manager, loop_id, parent_loop_id))
|
||||
continue;
|
||||
|
||||
const auto parent_loop = loop_manager->get_loop_info(parent_loop_id);
|
||||
if (can_be_split(loop, parent_loop)) {
|
||||
loop_was_split = true;
|
||||
const bool split_parent = parent_loop->increment < loop->increment;
|
||||
const auto& loop_to_split = split_parent ? parent_loop : loop;
|
||||
const auto& loop_to_split_id = split_parent ? parent_loop_id : loop_id;
|
||||
const auto& loop_to_fuse = !split_parent ? parent_loop : loop;
|
||||
loop_to_split->work_amount = loop_to_fuse->increment;
|
||||
|
||||
LinearIR::constExprIt loop_begin_pos, loop_end_pos;
|
||||
LoopManager::get_loop_bounds(linear_ir,
|
||||
loop_to_split->entry_points,
|
||||
loop_to_split->exit_points,
|
||||
loop_begin_pos,
|
||||
loop_end_pos,
|
||||
loop_to_split_id);
|
||||
const auto split_loop_id = loop_manager->mark_loop(loop_begin_pos,
|
||||
loop_end_pos,
|
||||
loop_to_fuse->work_amount,
|
||||
loop_to_fuse->increment,
|
||||
loop_to_split->dim_idx,
|
||||
loop_to_split->entry_points,
|
||||
loop_to_split->exit_points);
|
||||
loop_manager->get_loop_info(split_loop_id)->outer_splited_loop = true;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
// Ticket: 113666
|
||||
// FuseLoops pass is explicitly run here in order to avoid unnecessary computations
|
||||
// in case if loops are not split but FuseLoops is registered in pass manager after SplitLoops
|
||||
if (loop_was_split)
|
||||
FuseLoops().run(linear_ir);
|
||||
return loop_was_split;
|
||||
}
|
||||
} // namespace pass
|
||||
} // namespace lowered
|
||||
} // namespace snippets
|
||||
} // namespace ov
|
@ -24,6 +24,7 @@
|
||||
#include "snippets/lowered/linear_ir.hpp"
|
||||
#include "snippets/lowered/pass/assign_registers.hpp"
|
||||
#include "snippets/lowered/pass/mark_loops.hpp"
|
||||
#include "snippets/lowered/pass/split_loops.hpp"
|
||||
#include "snippets/lowered/pass/fuse_loops.hpp"
|
||||
#include "snippets/lowered/pass/init_loops.hpp"
|
||||
#include "snippets/lowered/pass/insert_buffers.hpp"
|
||||
@ -507,6 +508,7 @@ void snippets::op::Subgraph::data_flow_transformations(ov::pass::Manager& pre_co
|
||||
}
|
||||
|
||||
void snippets::op::Subgraph::control_flow_transformations(lowered::LinearIR& linear_ir,
|
||||
lowered::pass::PassPipeline& target_markup_pipeline,
|
||||
lowered::pass::PassPipeline& target_pipeline) {
|
||||
INTERNAL_OP_SCOPE(Subgraph);
|
||||
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::op::control_flow_transformations")
|
||||
@ -514,10 +516,15 @@ void snippets::op::Subgraph::control_flow_transformations(lowered::LinearIR& lin
|
||||
const size_t vector_size = get_generator()->get_target_machine()->get_lanes();
|
||||
const int32_t buffer_allocation_rank = static_cast<int32_t>(linear_ir.get_config().m_loop_depth);
|
||||
|
||||
// Ticket: 113666
|
||||
// TODO: Make pass pipeline with backend passes more flexible
|
||||
target_markup_pipeline.run(linear_ir);
|
||||
|
||||
lowered::pass::PassPipeline common_pipeline;
|
||||
common_pipeline.register_pass<lowered::pass::MarkLoops>(vector_size);
|
||||
common_pipeline.register_pass<lowered::pass::SoftmaxDecomposition>(vector_size);
|
||||
common_pipeline.register_pass<lowered::pass::FuseLoops>();
|
||||
common_pipeline.register_pass<lowered::pass::SplitLoops>();
|
||||
common_pipeline.register_pass<lowered::pass::MoveResultOutOfLoop>();
|
||||
common_pipeline.register_pass<lowered::pass::InsertBuffers>(buffer_allocation_rank);
|
||||
common_pipeline.register_pass<lowered::pass::InsertLoadStore>(vector_size);
|
||||
@ -557,22 +564,24 @@ snippets::Schedule snippets::op::Subgraph::generate(const BlockedShapeVector& ou
|
||||
ov::pass::Manager& pre_common,
|
||||
ov::pass::Manager& post_common,
|
||||
ov::pass::Manager& post_precision,
|
||||
lowered::pass::PassPipeline& target_lowered_markup_pipeline,
|
||||
lowered::pass::PassPipeline& target_lowered_pipeline,
|
||||
const void* compile_params) {
|
||||
canonicalize(output_shapes, input_shapes);
|
||||
return generate(pre_common, post_common, post_precision, target_lowered_pipeline, compile_params);
|
||||
return generate(pre_common, post_common, post_precision, target_lowered_markup_pipeline, target_lowered_pipeline, compile_params);
|
||||
}
|
||||
|
||||
snippets::Schedule snippets::op::Subgraph::generate(const void* compile_params) {
|
||||
auto mngr = ov::pass::Manager();
|
||||
auto lowered = lowered::pass::PassPipeline();
|
||||
return generate(mngr, mngr, mngr, lowered, compile_params);
|
||||
return generate(mngr, mngr, mngr, lowered, lowered, compile_params);
|
||||
}
|
||||
|
||||
snippets::Schedule snippets::op::Subgraph::generate(
|
||||
ov::pass::Manager& pre_common,
|
||||
ov::pass::Manager& post_common,
|
||||
ov::pass::Manager& post_precision,
|
||||
lowered::pass::PassPipeline& target_lowered_markup_pipeline,
|
||||
lowered::pass::PassPipeline& target_lowered_pipeline,
|
||||
const void* compile_params) {
|
||||
INTERNAL_OP_SCOPE(Subgraph);
|
||||
@ -587,7 +596,7 @@ snippets::Schedule snippets::op::Subgraph::generate(
|
||||
lowering_config.m_loop_depth = tileRank;
|
||||
|
||||
lowered::LinearIR linear_ir = lowered::LinearIR(body_ptr(), lowering_config);
|
||||
control_flow_transformations(linear_ir, target_lowered_pipeline);
|
||||
control_flow_transformations(linear_ir, target_lowered_markup_pipeline, target_lowered_pipeline);
|
||||
|
||||
// actual code emission
|
||||
const auto& lowering_result = m_generator->generate(linear_ir, lowering_config, compile_params);
|
||||
|
@ -126,7 +126,8 @@ std::shared_ptr<ov::snippets::op::Subgraph> LoweringTests::getLoweredSubgraph(co
|
||||
}
|
||||
body_rt_info["PluginShapesOverride"] = new_shapes;
|
||||
subgraph->set_tile_rank(2);
|
||||
subgraph->generate(pre_dialect, post_precision, post_precision, lowered_pipeline);
|
||||
ov::snippets::lowered::pass::PassPipeline empty_pipeline;
|
||||
subgraph->generate(pre_dialect, post_precision, post_precision, empty_pipeline, lowered_pipeline);
|
||||
return subgraph;
|
||||
}
|
||||
|
||||
|
@ -701,8 +701,8 @@ void StoreConvertEmitter::emit_isa(const std::vector<size_t> &in, const std::vec
|
||||
void StoreConvertEmitter::emit_data() const {
|
||||
store_emitter->emit_data();
|
||||
}
|
||||
size_t BrgemmEmitter::getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) const {
|
||||
return mIdx * 4 + kIdx * 2 + nIdx;
|
||||
size_t BrgemmEmitter::getBrgIdx(size_t kIdx, size_t nIdx) const {
|
||||
return kIdx * 2 + nIdx;
|
||||
}
|
||||
BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa,
|
||||
const std::shared_ptr<ov::Node>& node) : jit_emitter(h, isa, node) {
|
||||
@ -758,10 +758,8 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl:
|
||||
return std::distance(layout.begin(), std::find(layout.begin(), layout.end(), idx));
|
||||
};
|
||||
|
||||
m_M = C_shape[get_ordered_idx(C_layout, C_layout.size() - 2)];
|
||||
m_K = A_shape[get_ordered_idx(A_layout, A_layout.size() - 1)];
|
||||
m_M_blk = matmulOptimalM;
|
||||
m_M_tail = m_M % m_M_blk;
|
||||
m_M = brgemm_node->get_input_count(0);
|
||||
m_N = C_shape[get_ordered_idx(C_layout, C_layout.size() - 1)];
|
||||
|
||||
auto brg0Prc = InferenceEngine::details::convertPrecision(brgemm_node->get_input_element_type(0));
|
||||
@ -780,34 +778,28 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl:
|
||||
: m_K;
|
||||
m_K_tail = m_K % m_K_blk;
|
||||
|
||||
size_t brg0BaseIdx = std::numeric_limits<size_t>::max();
|
||||
for (size_t m = 0; m < 2; m++) {
|
||||
for (size_t k = 0; k < 2; k++) {
|
||||
for (size_t n = 0; n < 2; n++) {
|
||||
auto& brgemmCtx = m_brgCtxs0[getBrgIdx(m, k, n)];
|
||||
for (size_t k = 0; k < 2; k++) {
|
||||
for (size_t n = 0; n < 2; n++) {
|
||||
auto& brgemmCtx = m_brgCtxs0[getBrgIdx(k, n)];
|
||||
|
||||
auto M_ = m ? m_M_tail
|
||||
: m_M < m_M_blk ? 0 : m_M_blk;
|
||||
auto N_ = n ? m_N_tail : m_N - m_N_tail;
|
||||
auto K_ = k ? m_K_tail : m_K - m_K_tail;
|
||||
auto beta = k && m_brgCtxs0[getBrgIdx(m, 0, n)].K != 0 ? 1.0f : 0.0f;
|
||||
auto M_ = m_M;
|
||||
auto N_ = n ? m_N_tail : m_N - m_N_tail;
|
||||
auto K_ = k ? m_K_tail : m_K - m_K_tail;
|
||||
auto beta = k && m_brgCtxs0[getBrgIdx(0, n)].K != 0 ? 1.0f : 0.0f;
|
||||
|
||||
brgemmCtx.M = M_;
|
||||
brgemmCtx.N = N_;
|
||||
brgemmCtx.K = K_;
|
||||
brgemmCtx.LDA = leading_dimensions[0];
|
||||
brgemmCtx.LDB = brgemm_node->is_with_data_repacking() ? rnd_up(m_N, m_N_blk) : leading_dimensions[1];
|
||||
brgemmCtx.LDC = leading_dimensions[2];
|
||||
brgemmCtx.dt_in0 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::IEPrecisionToDataType(brg0Prc));
|
||||
brgemmCtx.dt_in1 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::IEPrecisionToDataType(brg1Prc));
|
||||
brgemmCtx.beta = beta;
|
||||
brgemmCtx.M = M_;
|
||||
brgemmCtx.N = N_;
|
||||
brgemmCtx.K = K_;
|
||||
brgemmCtx.LDA = leading_dimensions[0];
|
||||
brgemmCtx.LDB = brgemm_node->is_with_data_repacking() ? rnd_up(m_N, m_N_blk) : leading_dimensions[1];
|
||||
brgemmCtx.LDC = leading_dimensions[2];
|
||||
brgemmCtx.dt_in0 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::IEPrecisionToDataType(brg0Prc));
|
||||
brgemmCtx.dt_in1 = static_cast<dnnl_data_type_t>(DnnlExtensionUtils::IEPrecisionToDataType(brg1Prc));
|
||||
brgemmCtx.beta = beta;
|
||||
|
||||
// don't create brgemm kernels for empty tiles
|
||||
if (M_ != 0 && K_ != 0 && N_ != 0) {
|
||||
if (brg0BaseIdx == std::numeric_limits<size_t>::max())
|
||||
brg0BaseIdx = getBrgIdx(m, k, n);
|
||||
initBrgemm(brgemmCtx, m_brgKernels0[getBrgIdx(m, k, n)], brgWithAMX);
|
||||
}
|
||||
// don't create brgemm kernels for empty tiles
|
||||
if (M_ != 0 && K_ != 0 && N_ != 0) {
|
||||
initBrgemm(brgemmCtx, m_brgKernels0[getBrgIdx(k, n)], brgWithAMX);
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -878,36 +870,31 @@ void BrgemmEmitter::emit_impl(const std::vector<size_t>& in,
|
||||
}
|
||||
Xbyak::Reg64 output_0(static_cast<int>(out[0]));
|
||||
|
||||
for (size_t mb = 0; mb < div_up(m_M, m_M_blk); mb++) {
|
||||
const bool is_M_tail = (m_M - mb * m_M_blk < m_M_blk);
|
||||
size_t brgIdx0 = getBrgIdx(0, 0);
|
||||
size_t K0_step0 = m_brgCtxs0[brgIdx0].K;
|
||||
size_t K0_step1 = m_brgCtxs0[brgIdx0].K * m_brgCtxs0[brgIdx0].LDB;
|
||||
size_t N0_step0 = m_brgCtxs0[brgIdx0].N * m_brg0VnniFactor;
|
||||
size_t N0_step1 = m_brgCtxs0[brgIdx0].N;
|
||||
for (size_t n = 0; n < 2; n++) {
|
||||
for (size_t k = 0; k < 2; k++) {
|
||||
auto& brgemmCtx = m_brgCtxs0[getBrgIdx(k, n)];
|
||||
|
||||
size_t brgIdx0 = getBrgIdx(0, 0, 0);
|
||||
size_t K0_step0 = m_brgCtxs0[brgIdx0].K;
|
||||
size_t K0_step1 = m_brgCtxs0[brgIdx0].K * m_brgCtxs0[brgIdx0].LDB;
|
||||
size_t N0_step0 = m_brgCtxs0[brgIdx0].N * m_brg0VnniFactor;
|
||||
size_t N0_step1 = m_brgCtxs0[brgIdx0].N;
|
||||
for (size_t n = 0; n < 2; n++) {
|
||||
for (size_t k = 0; k < 2; k++) {
|
||||
size_t mIdx = is_M_tail ? 1 : 0;
|
||||
auto& brgemmCtx = m_brgCtxs0[getBrgIdx(mIdx, k, n)];
|
||||
if (brgemmCtx.K != 0 && brgemmCtx.N != 0) {
|
||||
const size_t in0_offset = m_load_offset_a + k * K0_step0 * io_data_size[0];
|
||||
const size_t in1_offset = m_load_offset_b + (k * K0_step1 + n * N0_step0) * io_data_size[1];
|
||||
const size_t in2_offset = m_load_offset_scratch + (m_with_comp ? n * N0_step1 * sizeof(int32_t) : 0);
|
||||
const size_t out0_offset = m_store_offset_c + n * N0_step1 * io_data_size[2];
|
||||
|
||||
if (brgemmCtx.K != 0 && brgemmCtx.N != 0) {
|
||||
const size_t in0_offset = m_load_offset_a + (k * K0_step0 + mb * m_M_blk * brgemmCtx.LDA) * io_data_size[0];
|
||||
const size_t in1_offset = m_load_offset_b + (k * K0_step1 + n * N0_step0) * io_data_size[1];
|
||||
const size_t in2_offset = m_load_offset_scratch + (m_with_comp ? n * N0_step1 * sizeof(int32_t) : 0);
|
||||
const size_t out0_offset = m_store_offset_c + (n * N0_step1 + mb * m_M_blk * brgemmCtx.LDC) * io_data_size[2];
|
||||
|
||||
emit_brgemm_kernel_call(m_brgKernels0[getBrgIdx(mIdx, k, n)].get(),
|
||||
brgemmCtx,
|
||||
input_0,
|
||||
input_1,
|
||||
input_2,
|
||||
output_0,
|
||||
in0_offset,
|
||||
in1_offset,
|
||||
in2_offset,
|
||||
out0_offset);
|
||||
}
|
||||
emit_brgemm_kernel_call(m_brgKernels0[getBrgIdx(k, n)].get(),
|
||||
brgemmCtx,
|
||||
input_0,
|
||||
input_1,
|
||||
input_2,
|
||||
output_0,
|
||||
in0_offset,
|
||||
in1_offset,
|
||||
in2_offset,
|
||||
out0_offset);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -353,7 +353,7 @@ private:
|
||||
float beta;
|
||||
};
|
||||
void initBrgemm(brgemmCtx& ctx, std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& brgKernel, bool use_amx) const;
|
||||
size_t getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) const;
|
||||
size_t getBrgIdx(size_t kIdx, size_t nIdx) const;
|
||||
|
||||
void emit_brgemm_kernel_call(const dnnl::impl::cpu::x64::brgemm_kernel_t* brg_kernel, const brgemmCtx& ctx,
|
||||
Xbyak::Reg64 addr_A, Xbyak::Reg64 addr_B, Xbyak::Reg64 scratch, Xbyak::Reg64 addr_C,
|
||||
@ -362,11 +362,10 @@ private:
|
||||
static void kernel_execute(const dnnl::impl::cpu::x64::brgemm_kernel_t *brg_kernel, const void *A, const void *B, void *C, void *scratch, int with_comp);
|
||||
|
||||
static constexpr size_t BRGEMM_KERNELS_NUM = 8;
|
||||
static constexpr size_t matmulOptimalM = 32;
|
||||
brgemmCtx m_brgCtxs0[BRGEMM_KERNELS_NUM];
|
||||
std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t> m_brgKernels0[BRGEMM_KERNELS_NUM];
|
||||
|
||||
size_t m_M, m_M_blk, m_M_tail;
|
||||
size_t m_M;
|
||||
size_t m_K, m_K_blk, m_K_tail;
|
||||
size_t m_N, m_N_blk, m_N_tail;
|
||||
size_t m_brg0VnniFactor;
|
||||
|
@ -25,6 +25,7 @@
|
||||
#include "utils/cpu_utils.hpp"
|
||||
#include "emitters/x64/cpu_generator.hpp"
|
||||
#include "transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.hpp"
|
||||
#include "transformations/snippets/x64/pass/lowered/brgemm_blocking.hpp"
|
||||
#include "transformations/snippets/x64/pass/mul_add_to_fma.hpp"
|
||||
#include "transformations/snippets/x64/pass/brgemm_to_brgemm_cpu.hpp"
|
||||
#include "transformations/snippets/x64/pass/remove_converts.hpp"
|
||||
@ -564,6 +565,9 @@ void Snippet::generate(const jit_snippets_compile_args* jcp) {
|
||||
CPU_REGISTER_PASS_X64(post_precision, ov::intel_cpu::pass::RemoveConverts);
|
||||
CPU_REGISTER_PASS_X64(post_precision, ov::intel_cpu::pass::MulAddToFMA);
|
||||
|
||||
ov::snippets::lowered::pass::PassPipeline control_flow_markup_pipeline;
|
||||
CPU_REGISTER_PASS_X64(control_flow_markup_pipeline, ov::intel_cpu::pass::BrgemmBlocking);
|
||||
|
||||
ov::snippets::lowered::pass::PassPipeline control_flow_pipeline;
|
||||
CPU_REGISTER_PASS_X64(control_flow_pipeline, ov::intel_cpu::pass::FuseLoadStoreConvert);
|
||||
|
||||
@ -571,6 +575,7 @@ void Snippet::generate(const jit_snippets_compile_args* jcp) {
|
||||
pre_dialect,
|
||||
post_dialect,
|
||||
post_precision,
|
||||
control_flow_markup_pipeline,
|
||||
control_flow_pipeline,
|
||||
reinterpret_cast<const void*>(jcp));
|
||||
}
|
||||
|
@ -0,0 +1,80 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "brgemm_blocking.hpp"
|
||||
|
||||
#include "openvino/pass/pattern/matcher.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "snippets/itt.hpp"
|
||||
#include "snippets/lowered/linear_ir.hpp"
|
||||
#include "snippets/lowered/loop_manager.hpp"
|
||||
#include "snippets/snippets_isa.hpp"
|
||||
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
|
||||
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
namespace pass {
|
||||
using LoopManager = snippets::lowered::LinearIR::LoopManager;
|
||||
using LoopInfoPtr = LoopManager::LoopInfoPtr;
|
||||
using LoopPort = LoopManager::LoopPort;
|
||||
|
||||
BrgemmBlocking::BrgemmBlocking() : Pass() {}
|
||||
|
||||
bool BrgemmBlocking::run(snippets::lowered::LinearIR& linear_ir) {
|
||||
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::BrgemmBlocking")
|
||||
if (linear_ir.empty())
|
||||
return false;
|
||||
|
||||
// Ticket: 113745
|
||||
// TODO: make the block size configurable
|
||||
const auto block_size = 32;
|
||||
const auto dim_idx = 1;
|
||||
|
||||
const auto& loop_manager = linear_ir.get_loop_manager();
|
||||
|
||||
auto blocking_loop_exists = [&](const ov::snippets::lowered::ExpressionPtr& expr,
|
||||
const std::shared_ptr<ov::intel_cpu::BrgemmCPU>& brgemm) {
|
||||
const auto& loop_ids = expr->get_loop_ids();
|
||||
for (const auto& id : loop_ids) {
|
||||
const auto loop = loop_manager->get_loop_info(id);
|
||||
if (loop->dim_idx == dim_idx) {
|
||||
OPENVINO_ASSERT(brgemm->get_input_count(0) == loop->increment,
|
||||
"Brgemm ", brgemm, " has input count (", brgemm->get_input_count(0),
|
||||
") which doesn't match the increment(", loop->increment, ") of loop by M");
|
||||
return true;
|
||||
}
|
||||
}
|
||||
return false;
|
||||
};
|
||||
|
||||
bool modified = false;
|
||||
for (auto expr_it = linear_ir.begin(); expr_it != linear_ir.end(); expr_it++) {
|
||||
const auto& expr = *expr_it;
|
||||
const auto brgemm = ov::as_type_ptr<ov::intel_cpu::BrgemmCPU>(expr->get_node());
|
||||
if (!brgemm || blocking_loop_exists(expr, brgemm))
|
||||
continue;
|
||||
|
||||
const auto& input_shape_0 = expr->get_input_port_descriptor(0)->get_shape();
|
||||
const auto& input_layout_0 = expr->get_input_port_descriptor(0)->get_layout();
|
||||
const auto& dim = *(input_layout_0.rbegin() + dim_idx);
|
||||
const auto& m = input_shape_0[dim];
|
||||
|
||||
brgemm->set_input_count(block_size);
|
||||
|
||||
const auto work_amount = m;
|
||||
const auto increment = block_size;
|
||||
|
||||
std::vector<LoopPort> entries{LoopPort(expr->get_input_port(0), true), LoopPort(expr->get_input_port(1), false)};
|
||||
if (brgemm->is_with_scratchpad())
|
||||
entries.emplace_back(expr->get_input_port(2), false);
|
||||
std::vector<LoopPort> exits{LoopPort(expr->get_output_port(0), true)};
|
||||
loop_manager->mark_loop(expr_it, std::next(expr_it), work_amount, increment, dim_idx, entries, exits);
|
||||
}
|
||||
|
||||
return modified;
|
||||
}
|
||||
} // namespace pass
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
@ -0,0 +1,28 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "snippets/lowered/pass/pass.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
namespace pass {
|
||||
|
||||
/**
|
||||
* @interface BrgemmBlocking
|
||||
* @brief Covers BrgemmCPU with blocking loop by M
|
||||
* @ingroup snippets
|
||||
*/
|
||||
|
||||
class BrgemmBlocking : public snippets::lowered::pass::Pass {
|
||||
public:
|
||||
OPENVINO_RTTI("BrgemmBlocking", "Pass")
|
||||
BrgemmBlocking();
|
||||
bool run(snippets::lowered::LinearIR& linear_ir) override;
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
@ -21,7 +21,7 @@ namespace pass {
|
||||
class FuseLoadStoreConvert: public snippets::lowered::pass::Pass {
|
||||
public:
|
||||
FuseLoadStoreConvert() = default;
|
||||
OPENVINO_RTTI("FuseLoadStoreConvert", "LinearIRTransformation");
|
||||
OPENVINO_RTTI("FuseLoadStoreConvert", "Pass");
|
||||
bool run(snippets::lowered::LinearIR& linear_ir) override;
|
||||
|
||||
private:
|
||||
|
Loading…
Reference in New Issue
Block a user