[Snippets] init_ptr_increments: removed useless computations for non incremented ports (#19224)

This commit is contained in:
Vladislav Golubev 2023-08-17 13:52:55 +02:00 committed by GitHub
parent 5d8abfe41f
commit 3e0f529700
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -32,34 +32,38 @@ InitLoops::InitLoops() : Pass() {}
void InitLoops::init_ptr_increments(std::vector<LoopPort>& loop_inputs, std::vector<LoopPort>& loop_outputs, size_t work_amount, size_t dim_idx) {
for (auto& loop_input : loop_inputs) {
const auto& port = loop_input.expr_port;
const auto source = *port->get_connected_ports().begin();
const auto loop_ids = port->get_expr()->get_loop_ids();
const auto& layout = port->get_descriptor_ptr()->get_layout();
const auto& shape = port->get_descriptor_ptr()->get_shape();
const auto& dim = *(layout.rbegin() + dim_idx);
loop_input.ptr_increment = 0;
// If relevant dim is not broadcasted, then ptr_increment is the dim stride in the new layout
if (loop_input.is_incremented && !(shape[dim] == 1 && work_amount != 1)) {
loop_input.ptr_increment = get_dim_stride(dim, source.get_descriptor_ptr()->get_layout(), shape);
if (loop_input.is_incremented) {
const auto& port = loop_input.expr_port;
const auto source = *port->get_connected_ports().begin();
const auto loop_ids = port->get_expr()->get_loop_ids();
const auto& layout = port->get_descriptor_ptr()->get_layout();
const auto& shape = port->get_descriptor_ptr()->get_shape();
const auto& dim = *(layout.rbegin() + dim_idx);
// If relevant dim is not broadcasted, then ptr_increment is the dim stride in the new layout
if (!(shape[dim] == 1 && work_amount != 1)) {
loop_input.ptr_increment = get_dim_stride(dim, source.get_descriptor_ptr()->get_layout(), shape);
}
}
}
for (auto& loop_output : loop_outputs) {
const auto& port = loop_output.expr_port;
const auto loop_ids = port->get_expr()->get_loop_ids();
const auto& layout = port->get_descriptor_ptr()->get_layout();
const auto& shape = port->get_descriptor_ptr()->get_shape();
const auto& dim = *(layout.rbegin() + dim_idx);
// Ticket: 113106
// WA: the current logic doesn't support the case with transposed output shape for brgemm layer
// but for all existing cases planar layout can be used
std::vector<size_t> planar(layout.size());
std::iota(planar.begin(), planar.end(), 0);
loop_output.ptr_increment = 0;
// If relevant dim is not broadcasted, then ptr_increment is the dim stride in the new layout
if (loop_output.is_incremented && !(shape[dim] == 1 && work_amount != 1)) {
loop_output.ptr_increment = get_dim_stride(dim, planar, shape);
if (loop_output.is_incremented) {
const auto& port = loop_output.expr_port;
const auto loop_ids = port->get_expr()->get_loop_ids();
const auto& layout = port->get_descriptor_ptr()->get_layout();
const auto& shape = port->get_descriptor_ptr()->get_shape();
const auto& dim = *(layout.rbegin() + dim_idx);
// Ticket: 113106
// WA: the current logic doesn't support the case with transposed output shape for brgemm layer
// but for all existing cases planar layout can be used
std::vector<size_t> planar(layout.size());
std::iota(planar.begin(), planar.end(), 0);
// If relevant dim is not broadcasted, then ptr_increment is the dim stride in the new layout
if (!(shape[dim] == 1 && work_amount != 1)) {
loop_output.ptr_increment = get_dim_stride(dim, planar, shape);
}
}
}
}