[Snippets] Changed BrgemmCopyB shape inference (#19957)

This commit is contained in:
Vladislav Golubev 2023-10-12 14:34:53 +02:00 committed by GitHub
parent 518a879a83
commit 377e927149
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 197 additions and 182 deletions

View File

@ -39,14 +39,6 @@ public:
bool has_evaluate() const override { return false; }
class ShapeInfer : public IShapeInferSnippets {
protected:
std::vector<std::vector<size_t>> m_io_layouts;
public:
explicit ShapeInfer(const std::shared_ptr<Node>& n);
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};
protected:
ov::element::Type get_output_type() const;
std::vector<ov::PartialShape> get_planar_input_shapes(const std::vector<ov::Input<ov::Node>>& inputs) const;

View File

@ -44,9 +44,10 @@ public:
size_t get_id() const { return m_id; }
Type get_type() const { return m_type; }
ov::Shape get_allocation_shape() const { return m_shape; }
int64_t get_offset() const { return m_offset; }
void set_id(size_t id) { m_id = id; }
const ov::Shape& get_allocation_shape() const { return m_shape; }
void set_allocation_shape(const ov::Shape& allocation_shape) { m_shape = allocation_shape; }
void set_offset(int64_t offset) { m_offset = offset; }
size_t get_byte_size() const;

View File

@ -61,5 +61,12 @@ public:
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};
// Shape inference for the Brgemm (batched matmul) snippets operation:
// planarizes inputs via the stored layouts, applies MatMul broadcasting
// rules, and re-applies the output layout to the result.
class BrgemmShapeInfer : public IShapeInferSnippets {
    // Layouts of all inputs followed by the output layout (filled in the ctor,
    // so the output layout is always m_io_layouts.back())
    std::vector<std::vector<size_t>> m_io_layouts;
public:
    explicit BrgemmShapeInfer(const std::shared_ptr<Node>& n);
    Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
};
} // namespace snippets
} // namespace ov

View File

@ -188,77 +188,6 @@ ov::PartialShape Brgemm::get_output_partial_shape(const std::vector<ov::PartialS
}
return output_shape;
}
Brgemm::ShapeInfer::ShapeInfer(const std::shared_ptr<Node>& n) {
    // Cache the layout of every input port first, then append the output
    // layout, so the output layout is always the last element.
    const auto inputs = n->inputs();
    m_io_layouts.reserve(inputs.size() + 1);
    for (const auto& input : inputs) {
        const auto& desc = lowered::PortDescriptorUtils::get_port_descriptor_ptr(input);
        m_io_layouts.emplace_back(desc->get_layout());
    }
    m_io_layouts.emplace_back(get_output_layout(n));
}
// Computes the BRGEMM (batched matmul) output shape from the two input shapes,
// following numpy/ONNX MatMul semantics: 1D inputs are temporarily unsqueezed,
// batch dims are broadcast, and the result is [batch..., M, N].
IShapeInferSnippets::Result Brgemm::ShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
    OPENVINO_ASSERT(input_shapes.size() == 2, "BRGEMM expects 2 input shapes for shape inference");
    // Todo: Ideally we should use the layout stored in PortDescriptors. Can we do it?
    // Planarize both inputs (apply the stored layouts) before any rank manipulation
    const auto& arg0_shape = snippets::utils::get_planar_vdims(input_shapes[0].get(), m_io_layouts[0]);
    const auto& arg1_shape = snippets::utils::get_planar_vdims(input_shapes[1].get(), m_io_layouts[1]);
    size_t arg0_rank = arg0_shape.size(), arg1_rank = arg1_shape.size();
    // temporary shapes to calculate output shape
    VectorDims arg0_shape_tmp(arg0_shape), arg1_shape_tmp(arg1_shape);
    // one-dimensional tensors unsqueezing is applied to each input independently.
    if (arg0_rank == 1) {
        // If the first input is 1D tensor, it is unsqueezed to 2D tensor (row vector)
        // by adding axes with size 1 at ROW_INDEX_DIM, to the left of the shape.
        // For example {S} will be reshaped to {1, S}.
        arg0_shape_tmp.insert(arg0_shape_tmp.begin(), 1);
        arg0_rank = arg0_shape_tmp.size();
    }
    if (arg1_rank == 1) {
        // If the second input is 1D tensor, it is unsqueezed to 2D tensor (column vector)
        // by adding axes with size 1 at COL_INDEX_DIM, to the right of the shape.
        // For example {S} will be reshaped to {S, 1}.
        arg1_shape_tmp.insert(arg1_shape_tmp.end(), 1);
        arg1_rank = arg1_shape_tmp.size();
    }
    // add 1 to begin to align shape ranks if needed
    if (arg0_rank < arg1_rank)
        arg0_shape_tmp.insert(arg0_shape_tmp.begin(), arg1_rank - arg0_rank, 1);
    else if (arg0_rank > arg1_rank)
        arg1_shape_tmp.insert(arg1_shape_tmp.begin(), arg0_rank - arg1_rank, 1);
    size_t max_rank = arg0_shape_tmp.size();
    VectorDims output_shape(max_rank);
    // Broadcast all batch (non-matmul) dims: equal dims pass through; a dim of 1
    // (or a dynamic dim) takes the other input's dim; anything else is an error.
    for (size_t i = 0; i < max_rank - 2; ++i) {
        if (arg0_shape_tmp[i] == arg1_shape_tmp[i]) {
            output_shape[i] = arg0_shape_tmp[i];
        } else {
            if (arg0_shape_tmp[i] == 1 || arg0_shape_tmp[i] == DYNAMIC_DIMENSION)
                output_shape[i] = arg1_shape_tmp[i];
            else if (arg1_shape_tmp[i] == 1 || arg1_shape_tmp[i] == DYNAMIC_DIMENSION)
                output_shape[i] = arg0_shape_tmp[i];
            else
                OPENVINO_THROW("Incompatible Brgemm batch dimension");
        }
    }
    output_shape[output_shape.size() - 2] = arg0_shape_tmp[arg0_shape_tmp.size() - 2]; // M
    output_shape[output_shape.size() - 1] = arg1_shape_tmp[arg1_shape_tmp.size() - 1]; // N
    // removing the temporary axes from originally 1D tensors.
    if (arg0_shape.size() == 1) {
        output_shape.erase(output_shape.begin() + output_shape.size() - 2);
    }
    if (arg1_shape.size() == 1) {
        output_shape.erase(output_shape.begin() + output_shape.size() - 1);
    }
    // Re-apply the output layout (index 2 = output in m_io_layouts) to the planar result
    output_shape = snippets::utils::get_planar_vdims(output_shape, m_io_layouts[2]);
    return {{output_shape}, snippets::ShapeInferStatus::success};
}
} // namespace op
} // namespace snippets
} // namespace ov

View File

@ -3,6 +3,7 @@
//
#include "snippets/shape_inference/shape_infer_instances.hpp"
#include "snippets/snippets_isa.hpp"
#include "snippets/utils.hpp"
#include "openvino/op/select.hpp"
namespace ov {
namespace snippets {
@ -160,5 +161,76 @@ Result HorizonOpShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes
return {{output_shapes}, ShapeInferStatus::success};
}
BrgemmShapeInfer::BrgemmShapeInfer(const std::shared_ptr<Node>& n) {
    // Collect per-input layouts first, then append the output(0) layout last,
    // so m_io_layouts.back() always refers to the output.
    for (const auto& input : n->inputs())
        m_io_layouts.emplace_back(lowered::PortDescriptorUtils::get_port_descriptor_ptr(input)->get_layout());
    m_io_layouts.emplace_back(lowered::PortDescriptorUtils::get_port_descriptor_ptr(n->output(0))->get_layout());
}
// Computes the BRGEMM output shape following numpy/ONNX MatMul semantics.
// Accepts 2 inputs (A, B) or 3 (A, B, scratchpad); only the first two
// participate in the shape computation.
Result BrgemmShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
    OPENVINO_ASSERT(input_shapes.size() == 2 || input_shapes.size() == 3, "BRGEMM expects 2 or 3 input shapes for shape inference");
    // Todo: Ideally we should use the layout stored in PortDescriptors. Can we do it?
    // Planarize both inputs (apply the stored layouts) before any rank manipulation
    const auto& arg0_shape = ov::snippets::utils::get_planar_vdims(input_shapes[0].get(), m_io_layouts[0]);
    const auto& arg1_shape = ov::snippets::utils::get_planar_vdims(input_shapes[1].get(), m_io_layouts[1]);
    size_t arg0_rank = arg0_shape.size(), arg1_rank = arg1_shape.size();
    // temporary shapes to calculate output shape
    VectorDims arg0_shape_tmp(arg0_shape), arg1_shape_tmp(arg1_shape);
    // one-dimensional tensors unsqueezing is applied to each input independently.
    if (arg0_rank == 1) {
        // If the first input is 1D tensor, it is unsqueezed to 2D tensor (row vector)
        // by adding axes with size 1 at ROW_INDEX_DIM, to the left of the shape.
        // For example {S} will be reshaped to {1, S}.
        arg0_shape_tmp.insert(arg0_shape_tmp.begin(), 1);
        arg0_rank = arg0_shape_tmp.size();
    }
    if (arg1_rank == 1) {
        // If the second input is 1D tensor, it is unsqueezed to 2D tensor (column vector)
        // by adding axes with size 1 at COL_INDEX_DIM, to the right of the shape.
        // For example {S} will be reshaped to {S, 1}.
        arg1_shape_tmp.insert(arg1_shape_tmp.end(), 1);
        arg1_rank = arg1_shape_tmp.size();
    }
    // add 1 to begin to align shape ranks if needed
    if (arg0_rank < arg1_rank)
        arg0_shape_tmp.insert(arg0_shape_tmp.begin(), arg1_rank - arg0_rank, 1);
    else if (arg0_rank > arg1_rank)
        arg1_shape_tmp.insert(arg1_shape_tmp.begin(), arg0_rank - arg1_rank, 1);
    size_t max_rank = arg0_shape_tmp.size();
    VectorDims output_shape(max_rank);
    // Broadcast all batch (non-matmul) dims: equal dims pass through; a dim of 1
    // (or a dynamic dim) takes the other input's dim; anything else is an error.
    for (size_t i = 0; i < max_rank - 2; ++i) {
        if (arg0_shape_tmp[i] == arg1_shape_tmp[i]) {
            output_shape[i] = arg0_shape_tmp[i];
        } else {
            if (arg0_shape_tmp[i] == 1 || arg0_shape_tmp[i] == DYNAMIC_DIMENSION)
                output_shape[i] = arg1_shape_tmp[i];
            else if (arg1_shape_tmp[i] == 1 || arg1_shape_tmp[i] == DYNAMIC_DIMENSION)
                output_shape[i] = arg0_shape_tmp[i];
            else
                OPENVINO_THROW("Incompatible Brgemm batch dimension");
        }
    }
    output_shape[output_shape.size() - 2] = arg0_shape_tmp[arg0_shape_tmp.size() - 2]; // M
    output_shape[output_shape.size() - 1] = arg1_shape_tmp[arg1_shape_tmp.size() - 1]; // N
    // removing the temporary axes from originally 1D tensors.
    if (arg0_shape.size() == 1) {
        output_shape.erase(output_shape.begin() + output_shape.size() - 2);
    }
    if (arg1_shape.size() == 1) {
        output_shape.erase(output_shape.begin() + output_shape.size() - 1);
    }
    // Re-apply the output layout (always stored last in m_io_layouts) to the planar result
    output_shape = ov::snippets::utils::get_planar_vdims(output_shape, m_io_layouts.back());
    return {{output_shape}, snippets::ShapeInferStatus::success};
}
} // namespace snippets
} // namespace ov

View File

@ -58,11 +58,11 @@ const IShapeInferSnippetsFactory::TRegistry IShapeInferSnippetsFactory::registry
SHAPE_INFER_PREDEFINED(op::Kernel, EmptyShapeInfer),
SHAPE_INFER_PREDEFINED(op::Nop, EmptyShapeInfer),
SHAPE_INFER_OP_SPECIFIC_EXTERNAL(opset1::Select, SelectShapeInfer),
SHAPE_INFER_OP_SPECIFIC_EXTERNAL(op::Brgemm, BrgemmShapeInfer),
// Note that Result has no output PortConnectors, so the shape must be empty
SHAPE_INFER_PREDEFINED(ov::op::v0::Result, EmptyShapeInfer),
//
SHAPE_INFER_OP_SPECIFIC(op::LoadReshape),
SHAPE_INFER_OP_SPECIFIC(op::Brgemm),
SHAPE_INFER_OP_SPECIFIC(op::BroadcastLoad),
SHAPE_INFER_OP_SPECIFIC(op::BroadcastMove),
};

View File

@ -22,6 +22,7 @@
#include "snippets/pass/matmul_to_brgemm.hpp"
#include "utils/cpu_utils.hpp"
#include "emitters/x64/cpu_generator.hpp"
#include "transformations/snippets/x64/pass/lowered/set_brgemm_copy_b_buffers_shape.hpp"
#include "transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.hpp"
#include "transformations/snippets/x64/pass/lowered/brgemm_blocking.hpp"
#include "transformations/snippets/x64/pass/mul_add_to_fma.hpp"
@ -618,6 +619,7 @@ void Snippet::SnippetJitExecutor::generate(const jit_snippets_compile_args* jcp)
ov::snippets::lowered::pass::PassPipeline control_flow_pipeline;
CPU_REGISTER_PASS_X64(control_flow_pipeline, ov::intel_cpu::pass::FuseLoadStoreConvert)
CPU_REGISTER_PASS_X64(control_flow_pipeline, ov::intel_cpu::pass::SetBrgemmCopyBBuffersShape);
// Note: we need to pass valid shapeInfer factory to generate, so it can be used in OptimizeDomain pass
// in all other cases nGraph shape inference will be used until ticket # 113209 (PR 18563) is merged
schedule = snippet_for_generation->generate(backend_passes,

View File

@ -57,39 +57,34 @@ void BrgemmCopyB::custom_constructor_validate_and_infer_types(std::vector<size_t
// During ctor call, BrgemmCopyB doesn't know his port descriptors.
// So we use port descs from source inputs
const auto element_type = get_input_element_type(0);
const auto pshape = snippets::utils::get_planar_pshape(get_input_partial_shape(0), layout_input);
validate(pshape, element_type);
const auto& pshape = get_input_partial_shape(0);
// The data always store in planar shape after repacking
const auto planar_pshape = snippets::utils::get_planar_pshape(pshape, layout_input);
// data repacking output
set_output_type(0, element_type, planar_pshape);
// If compensations are needed, they are provided in 2nd output (which is used in BrgemmCPU)
if (is_with_compensations()) {
set_output_type(1, ov::element::f32, planar_pshape);
}
validate(planar_pshape, element_type);
}
void BrgemmCopyB::validate_and_infer_types() {
INTERNAL_OP_SCOPE(BrgemmRepack_validate_and_infer_types);
const auto port = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0));
const auto shape = ov::Shape(port->get_shape());
const auto& element_type = get_input_element_type(0);
const auto& pshape = snippets::utils::get_planar_pshape(input(0));
validate(pshape, element_type);
const auto& planar_pshape = snippets::utils::get_planar_pshape(shape, port->get_layout());
set_output_type(0, element_type, planar_pshape);
if (is_with_compensations()) {
set_output_type(1, ov::element::f32, planar_pshape);
}
validate(planar_pshape, element_type);
}
void BrgemmCopyB::validate(const ov::PartialShape& pshape, const ov::element::Type& element_type) {
NGRAPH_CHECK(one_of(element_type, element::bf16, element::i8),
"BrgemmCopyB doesn't support element type" + element_type.get_type_name());
if (pshape.is_dynamic()) {
set_output_type(0, element_type, ov::PartialShape {ov::Dimension::dynamic()});
if (is_with_compensations()) {
set_output_type(1, ov::element::f32, ov::PartialShape {ov::Dimension::dynamic()});
}
return;
}
const auto shape = pshape.get_shape();
const auto N = *shape.rbegin();
const auto K = *(shape.rbegin() + 1);
set_output_type(0, element_type, ov::PartialShape{ov::Dimension(rnd_up(K, m_brgemmVNNIFactor)),
ov::Dimension(rnd_up(N, m_N_blk))});
if (is_with_compensations()) {
set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension(rnd_up(N, m_N_blk))});
}
// Validates that the input element type is supported by the repacking kernel.
// BrgemmCopyB repacks the B matrix for low-precision BRGEMM, so only bf16 and
// i8 inputs are allowed.
// NOTE(review): planar_pshape is currently unused here — presumably shape
// checks moved to shape inference; confirm before removing the parameter.
void BrgemmCopyB::validate(const ov::PartialShape& planar_pshape, const ov::element::Type& element_type) {
    // Fixed: the original message lacked a space after "type" ("element typebf16")
    OPENVINO_ASSERT(one_of(element_type, element::bf16, element::i8),
                    "BrgemmCopyB doesn't support element type " + element_type.get_type_name());
}
void intel_cpu::BrgemmCopyB::compute_block_size_values(const size_t blk_size_k, const size_t blk_size_n) {
@ -98,6 +93,17 @@ void intel_cpu::BrgemmCopyB::compute_block_size_values(const size_t blk_size_k,
m_N_blk = blk_size_n != 0 ? blk_size_n : *input_shape.rbegin();
}
ov::Shape intel_cpu::BrgemmCopyB::get_data_repacking_shape(const ov::snippets::VectorDims& planar_dims) const {
    // Allocation shape of the repacked B data: K (second-to-last planar dim) is
    // padded up to a multiple of the VNNI factor, N (last planar dim) up to a
    // multiple of the N block size.
    const auto n_dim = planar_dims.back();
    const auto k_dim = planar_dims[planar_dims.size() - 2];
    return ov::Shape{rnd_up(k_dim, m_brgemmVNNIFactor), rnd_up(n_dim, m_N_blk)};
}
ov::Shape intel_cpu::BrgemmCopyB::get_compensation_shape(const ov::snippets::VectorDims& planar_dims) const {
    // The compensation buffer is 1D over N (last planar dim), padded up to a
    // multiple of the N block size.
    return ov::Shape{rnd_up(planar_dims.back(), m_N_blk)};
}
std::shared_ptr<Node> intel_cpu::BrgemmCopyB::clone_with_new_inputs(const OutputVector& new_args) const {
INTERNAL_OP_SCOPE(BrgemmRepack_clone_with_new_inputs);
check_new_args_count(this, new_args);
@ -120,29 +126,13 @@ BrgemmCopyB::ShapeInfer::ShapeInfer(const std::shared_ptr<ov::Node>& n) {
OPENVINO_ASSERT(brg_copyb, "Got invalid node in BrgemmCopyB::ShapeInfer");
m_layout = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(n->input(0))->get_layout();
m_num_outs = brg_copyb->get_output_size();
m_N_blk = brg_copyb->get_n_block_size();
m_brgemmVNNIFactor = brg_copyb->m_brgemmVNNIFactor;
}
snippets::IShapeInferSnippets::Result BrgemmCopyB::ShapeInfer::infer(const std::vector<snippets::VectorDimsRef>& input_shapes) {
ov::snippets::IShapeInferSnippets::Result BrgemmCopyB::ShapeInfer::infer(const std::vector<ov::snippets::VectorDimsRef>& input_shapes) {
OPENVINO_ASSERT(input_shapes.size() == 1, "Got unexpected number of input shapes");
const auto& old_shape = input_shapes[0].get();
snippets::VectorDims planar_shape;
planar_shape.reserve(old_shape.size());
for (const auto idx : m_layout)
planar_shape.push_back(old_shape[idx]);
const auto N = *planar_shape.rbegin();
const auto K = *(planar_shape.rbegin() + 1);
OPENVINO_ASSERT(N != DYNAMIC_DIMENSION && K != DYNAMIC_DIMENSION,
"BrgemmCopyB shape infer got dynamic N or K dimension, which is not supported");
std::vector<snippets::VectorDims> new_shapes(m_num_outs);
new_shapes[0].push_back(rnd_up(K, m_brgemmVNNIFactor));
new_shapes[0].push_back(rnd_up(N, m_N_blk));
if (m_num_outs == 2) {
new_shapes[1].push_back(rnd_up(N, m_N_blk));
}
return {new_shapes, snippets::ShapeInferStatus::success};
const auto planar_shape = ov::snippets::utils::get_planar_vdims(input_shapes[0].get(), m_layout);
std::vector<ov::snippets::VectorDims> new_shapes(m_num_outs, planar_shape);
return {new_shapes, ov::snippets::ShapeInferStatus::success};
}
} // namespace intel_cpu

View File

@ -5,6 +5,7 @@
#pragma once
#include "snippets/op/memory_access.hpp"
#include "snippets/shape_types.hpp"
#include <snippets/shape_inference/shape_inference.hpp>
namespace ov {
@ -43,7 +44,11 @@ public:
void set_k_block_size(size_t block_size) { m_K_blk = block_size; }
void set_n_block_size(size_t block_size) { m_N_blk = block_size; }
ov::Shape get_data_repacking_shape(const ov::snippets::VectorDims& planar_dims) const;
ov::Shape get_compensation_shape(const ov::snippets::VectorDims& planar_dims) const;
Type get_type() const { return m_type; }
size_t get_brgemm_vnni_factor() const { return m_brgemmVNNIFactor; }
element::Type get_src_element_type() const { return m_src_type; }
bool is_with_compensations() const { return m_type == Type::WithCompensations; }
@ -55,8 +60,6 @@ public:
class ShapeInfer : public snippets::IShapeInferSnippets {
std::vector<size_t> m_layout{};
size_t m_num_outs = 1;
size_t m_N_blk = 64;
size_t m_brgemmVNNIFactor = 1;
public:
explicit ShapeInfer(const std::shared_ptr<ov::Node>& n);
Result infer(const std::vector<snippets::VectorDimsRef>& input_shapes) override;
@ -64,7 +67,7 @@ public:
private:
void custom_constructor_validate_and_infer_types(std::vector<size_t> layout_input = {});
void validate(const ov::PartialShape& pshape, const ov::element::Type& element_type);
void validate(const ov::PartialShape& planar_pshape, const ov::element::Type& element_type);
void compute_block_size_values(const size_t blk_size_k, const size_t blk_size_n);
Type m_type = Type::OnlyRepacking;

View File

@ -114,21 +114,9 @@ void BrgemmCPU::validate_with_scratchpad(const ov::Shape& shape_b) const {
// Additional check for 3rd input
if (one_of(m_type, Type::WithCompensations, Type::AMX)) {
const auto& pshape = get_input_partial_shape(2);
NGRAPH_CHECK(pshape.is_static(), "BRGEMM Scratch must have static shape");
const auto shape = pshape.to_shape();
const auto type = get_input_element_type(2);
OPENVINO_ASSERT(pshape.is_static(), "BRGEMM Scratch must have static shape");
if (is_with_compensations()) {
const auto expected_type = ov::element::f32;
NGRAPH_CHECK(expected_type == type, "BRGEMM Scratch with compensations must have FP32 element type");
const auto N = *shape_b.rbegin();
// If N block size is not set, there is no meaning in validating the scratchpad shape
if (m_N_blk != N) {
const auto expected_shape = ov::Shape{rnd_up(N, m_N_blk)};
NGRAPH_CHECK(expected_shape == shape, "BRGEMM Scratch with compensations must have shape {rnd_up(N, m_N_blk)}");
}
} else {
NGRAPH_CHECK(ov::shape_size(shape) == SCRATCH_BYTE_SIZE && type == ov::element::u8,
"BRGEMM Scratch for space workplace must be static, have U8 element type and size equal to " + std::to_string(SCRATCH_BYTE_SIZE));
OPENVINO_ASSERT(get_input_element_type(2) == ov::element::f32, "BRGEMM Scratch with compensations must have FP32 element type");
}
}
}
@ -181,13 +169,5 @@ size_t BrgemmCPU::get_offset_scratch() const {
return get_input_offset(2);
}
BrgemmCPU::ShapeInfer::ShapeInfer(const std::shared_ptr<ov::Node>& n) : Brgemm::ShapeInfer(n) {
    const auto brgemm = ov::as_type_ptr<BrgemmCPU>(n);
    OPENVINO_ASSERT(brgemm, "Got invalid node in BrgemmCPU::ShapeInfer");
    // When B is repacked via BrgemmCopyB, take input 1's layout from the copy
    // op's own input: the repacked data is stored in planar form already.
    if (brgemm->is_with_data_repacking()) {
        if (const auto& brgemm_copy = brgemm->get_brgemm_copy())
            m_io_layouts[1] = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(brgemm_copy->input(0))->get_layout();
    }
}
} // namespace intel_cpu
} // namespace ov

View File

@ -69,12 +69,6 @@ public:
constexpr static size_t SCRATCH_BYTE_SIZE = 32 * 1024;
class ShapeInfer : public Brgemm::ShapeInfer {
public:
explicit ShapeInfer(const std::shared_ptr<ov::Node>& n);
};
private:
void custom_constructor_validate_and_infer_types(std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c);
void compute_block_size_values(const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n);

View File

@ -0,0 +1,38 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "snippets/itt.hpp"
#include "set_brgemm_copy_b_buffers_shape.hpp"
#include "snippets/snippets_isa.hpp"
#include "snippets/utils.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
bool ov::intel_cpu::pass::SetBrgemmCopyBBuffersShape::run(snippets::lowered::LinearIR& linear_ir) {
    OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::SetBrgemmCopyBBuffersShape")
    // Returns the single Buffer node consuming the given output port of `expr`;
    // asserts if the port has more than one consumer or a non-Buffer consumer.
    auto get_out_buffer = [](const snippets::lowered::ExpressionPtr& expr, const size_t port_idx) {
        const auto& consumers = expr->get_output_port_connector(port_idx)->get_consumers();
        OPENVINO_ASSERT(consumers.size() == 1, "BrgemmCopyB must have only 1 consumer");
        const auto buffer = ov::as_type_ptr<ov::snippets::op::Buffer>(consumers.begin()->get_expr()->get_node());
        OPENVINO_ASSERT(buffer, "BrgemmCopyB consumer must be Buffer");
        return buffer;
    };
    bool modified = false;
    for (const auto& expr : linear_ir) {
        const auto copy_b = ov::as_type_ptr<ov::intel_cpu::BrgemmCopyB>(expr->get_node());
        if (!copy_b)
            continue;
        // Resize the data buffer (and the compensation buffer, if compensations
        // are used) according to the BrgemmCopyB blocking parameters.
        const auto& out_shape = expr->get_output_port_descriptor(0)->get_shape();
        get_out_buffer(expr, 0)->set_allocation_shape(copy_b->get_data_repacking_shape(out_shape));
        if (copy_b->is_with_compensations())
            get_out_buffer(expr, 1)->set_allocation_shape(copy_b->get_compensation_shape(out_shape));
        modified = true;
    }
    return modified;
}

View File

@ -0,0 +1,27 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "snippets/lowered/pass/pass.hpp"
namespace ov {
namespace intel_cpu {
namespace pass {
/**
* @interface SetBrgemmCopyBBuffersShape
* @brief Sets the allocation shape for the Buffers after BrgemmCopyB node using BrgemmCopyB parameters
* @ingroup snippets
*/
class SetBrgemmCopyBBuffersShape: public snippets::lowered::pass::Pass {
public:
    SetBrgemmCopyBBuffersShape() = default;
    OPENVINO_RTTI("SetBrgemmCopyBBuffersShape", "Pass");
    // Walks the LinearIR and updates the allocation shapes of the Buffers that
    // consume BrgemmCopyB outputs (repacked data and, if present, compensations).
    // Returns true if at least one buffer shape was modified.
    bool run(snippets::lowered::LinearIR& linear_ir) override;
};
} // namespace pass
} // namespace intel_cpu
} // namespace ov

View File

@ -22,18 +22,6 @@
namespace ov {
namespace intel_cpu {
using namespace snippets::lowered;
namespace {
template <typename T>
void change_desc_shape(const T& port) {
const auto desc = PortDescriptorUtils::get_port_descriptor_ptr(port);
const auto& shape = port.get_shape();
if (desc->get_shape() != shape) {
desc->set_shape(shape);
}
}
} // namespace
pass::SetBrgemmCPUBlockingParams::SetBrgemmCPUBlockingParams() {
MATCHER_SCOPE(SetBrgemmCPUBlockingParams);
@ -73,7 +61,7 @@ pass::SetBrgemmCPUBlockingParams::SetBrgemmCPUBlockingParams() {
const bool isAMXSupported = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx);
const auto precision = brgemm_copy_b->get_src_element_type();
const auto brgemmVNNIFactor = 4 / precision.size();
const auto brgemmVNNIFactor = brgemm_copy_b->get_brgemm_vnni_factor();
const bool use_amx = isAMXSupported && precision != ov::element::f32 && (K % brgemmVNNIFactor == 0) && (N % brgemmVNNIFactor == 0);
const size_t copy_b_block_size_k = use_amx ? brgemm_block_size_k : K;
@ -81,18 +69,8 @@ pass::SetBrgemmCPUBlockingParams::SetBrgemmCPUBlockingParams() {
brgemm_copy_b->set_k_block_size(copy_b_block_size_k);
brgemm_copy_b->set_n_block_size(copy_b_block_size_n);
// since N block size affects output shapes, the validation must be called explicitly right after the block size changing
brgemm_copy_b->validate_and_infer_types();
change_desc_shape(brgemm_copy_b->output(0));
if (brgemm_copy_b->is_with_compensations())
change_desc_shape(brgemm_copy_b->output(1));
}
brgemm->validate_and_infer_types();
change_desc_shape(brgemm->input(1));
if (brgemm->is_with_scratchpad())
change_desc_shape(brgemm->input(2));
return false;
};

View File

@ -28,6 +28,8 @@ ShapeInferPtr CPUShapeInferSnippetsFactory::get_specific_op_shape_infer(const ov
{ OP::get_type_info_static(), [](const std::shared_ptr<ov::Node>& n) { return std::make_shared<InferType>();} }
#define SHAPE_INFER_OP_SPECIFIC(OP) \
{ OP::get_type_info_static(), [](const std::shared_ptr<ov::Node>& n) { return std::make_shared<OP::ShapeInfer>(n);} }
#define SHAPE_INFER_OP_SPECIFIC_EXTERNAL(OP, InferType) \
{ OP::get_type_info_static(), [](const std::shared_ptr<ov::Node>& n) { return std::make_shared<InferType>(n);} }
const CPUShapeInferSnippetsFactory::TRegistry CPUShapeInferSnippetsFactory::specific_ops_registry {
SHAPE_INFER_PREDEFINED(ov::intel_cpu::FusedMulAdd, NumpyBroadcastShapeInfer),
@ -36,9 +38,9 @@ const CPUShapeInferSnippetsFactory::TRegistry CPUShapeInferSnippetsFactory::spec
SHAPE_INFER_PREDEFINED(ov::intel_cpu::LoadConvertTruncation, PassThroughShapeInfer),
SHAPE_INFER_PREDEFINED(ov::intel_cpu::StoreConvertSaturation, PassThroughShapeInfer),
SHAPE_INFER_PREDEFINED(ov::intel_cpu::StoreConvertTruncation, PassThroughShapeInfer),
SHAPE_INFER_OP_SPECIFIC_EXTERNAL(ov::intel_cpu::BrgemmCPU, BrgemmShapeInfer),
//
SHAPE_INFER_OP_SPECIFIC(ov::intel_cpu::BrgemmCopyB),
SHAPE_INFER_OP_SPECIFIC(ov::intel_cpu::BrgemmCPU),
};
#undef SHAPE_INFER_OP_SPECIFIC
#undef SHAPE_INFER_PREDEFINED