[Snippets] Changed BrgemmCopyB shape inference (#19957)
This commit is contained in:
parent
518a879a83
commit
377e927149
@ -39,14 +39,6 @@ public:
|
||||
|
||||
bool has_evaluate() const override { return false; }
|
||||
|
||||
class ShapeInfer : public IShapeInferSnippets {
|
||||
protected:
|
||||
std::vector<std::vector<size_t>> m_io_layouts;
|
||||
public:
|
||||
explicit ShapeInfer(const std::shared_ptr<Node>& n);
|
||||
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
|
||||
};
|
||||
|
||||
protected:
|
||||
ov::element::Type get_output_type() const;
|
||||
std::vector<ov::PartialShape> get_planar_input_shapes(const std::vector<ov::Input<ov::Node>>& inputs) const;
|
||||
|
@ -44,9 +44,10 @@ public:
|
||||
|
||||
size_t get_id() const { return m_id; }
|
||||
Type get_type() const { return m_type; }
|
||||
ov::Shape get_allocation_shape() const { return m_shape; }
|
||||
int64_t get_offset() const { return m_offset; }
|
||||
void set_id(size_t id) { m_id = id; }
|
||||
const ov::Shape& get_allocation_shape() const { return m_shape; }
|
||||
void set_allocation_shape(const ov::Shape& allocation_shape) { m_shape = allocation_shape; }
|
||||
void set_offset(int64_t offset) { m_offset = offset; }
|
||||
size_t get_byte_size() const;
|
||||
|
||||
|
@ -61,5 +61,12 @@ public:
|
||||
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
|
||||
};
|
||||
|
||||
// Shape inference for Brgemm-like operations. It is constructed from a node and
// computes a single MatMul-style output shape from the input shapes, taking the
// port layouts captured at construction time into account.
class BrgemmShapeInfer : public IShapeInferSnippets {
|
||||
// Layouts of every input port of the node (in port order), followed by the
// layout of output port 0 as the last element.
std::vector<std::vector<size_t>> m_io_layouts;
|
||||
public:
|
||||
// Reads the layouts from the port descriptors of `n` and stores them in m_io_layouts.
explicit BrgemmShapeInfer(const std::shared_ptr<Node>& n);
|
||||
// Expects 2 or 3 input shapes; only the first two are used to compute the output shape.
Result infer(const std::vector<VectorDimsRef>& input_shapes) override;
|
||||
};
|
||||
|
||||
} // namespace snippets
|
||||
} // namespace ov
|
||||
|
@ -188,77 +188,6 @@ ov::PartialShape Brgemm::get_output_partial_shape(const std::vector<ov::PartialS
|
||||
}
|
||||
return output_shape;
|
||||
}
|
||||
|
||||
Brgemm::ShapeInfer::ShapeInfer(const std::shared_ptr<Node>& n) {
|
||||
for (const auto& in : n->inputs()) {
|
||||
const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(in);
|
||||
m_io_layouts.push_back(port->get_layout());
|
||||
}
|
||||
m_io_layouts.push_back(get_output_layout(n));
|
||||
}
|
||||
|
||||
IShapeInferSnippets::Result Brgemm::ShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
|
||||
OPENVINO_ASSERT(input_shapes.size() == 2, "BRGEMM expects 2 input shapes for shape inference");
|
||||
|
||||
// Todo: Ideally we should use the layout stored in PortDescriptors. Can we do it?
|
||||
const auto& arg0_shape = snippets::utils::get_planar_vdims(input_shapes[0].get(), m_io_layouts[0]);
|
||||
const auto& arg1_shape = snippets::utils::get_planar_vdims(input_shapes[1].get(), m_io_layouts[1]);
|
||||
|
||||
size_t arg0_rank = arg0_shape.size(), arg1_rank = arg1_shape.size();
|
||||
|
||||
// temporary shapes to calculate output shape
|
||||
VectorDims arg0_shape_tmp(arg0_shape), arg1_shape_tmp(arg1_shape);
|
||||
|
||||
// one-dimensional tensors unsqueezing is applied to each input independently.
|
||||
if (arg0_rank == 1) {
|
||||
// If the first input is 1D tensor, it is unsqueezed to 2D tensor (row vector)
|
||||
// by adding axes with size 1 at ROW_INDEX_DIM, to the left of the shape.
|
||||
// For example {S} will be reshaped to {1, S}.
|
||||
arg0_shape_tmp.insert(arg0_shape_tmp.begin(), 1);
|
||||
arg0_rank = arg0_shape_tmp.size();
|
||||
}
|
||||
if (arg1_rank == 1) {
|
||||
// If the second input is 1D tensor, it is unsqueezed to 2D tensor (column vector)
|
||||
// by adding axes with size 1 at COL_INDEX_DIM, to the right of the shape.
|
||||
// For example {S} will be reshaped to {S, 1}.
|
||||
arg1_shape_tmp.insert(arg1_shape_tmp.end(), 1);
|
||||
arg1_rank = arg1_shape_tmp.size();
|
||||
}
|
||||
|
||||
// add 1 to begin to align shape ranks if needed
|
||||
if (arg0_rank < arg1_rank)
|
||||
arg0_shape_tmp.insert(arg0_shape_tmp.begin(), arg1_rank - arg0_rank, 1);
|
||||
else if (arg0_rank > arg1_rank)
|
||||
arg1_shape_tmp.insert(arg1_shape_tmp.begin(), arg0_rank - arg1_rank, 1);
|
||||
|
||||
size_t max_rank = arg0_shape_tmp.size();
|
||||
VectorDims output_shape(max_rank);
|
||||
for (size_t i = 0; i < max_rank - 2; ++i) {
|
||||
if (arg0_shape_tmp[i] == arg1_shape_tmp[i]) {
|
||||
output_shape[i] = arg0_shape_tmp[i];
|
||||
} else {
|
||||
if (arg0_shape_tmp[i] == 1 || arg0_shape_tmp[i] == DYNAMIC_DIMENSION)
|
||||
output_shape[i] = arg1_shape_tmp[i];
|
||||
else if (arg1_shape_tmp[i] == 1 || arg1_shape_tmp[i] == DYNAMIC_DIMENSION)
|
||||
output_shape[i] = arg0_shape_tmp[i];
|
||||
else
|
||||
OPENVINO_THROW("Incompatible Brgemm batch dimension");
|
||||
}
|
||||
}
|
||||
output_shape[output_shape.size() - 2] = arg0_shape_tmp[arg0_shape_tmp.size() - 2]; // M
|
||||
output_shape[output_shape.size() - 1] = arg1_shape_tmp[arg1_shape_tmp.size() - 1]; // N
|
||||
|
||||
// removing the temporary axes from originally 1D tensors.
|
||||
if (arg0_shape.size() == 1) {
|
||||
output_shape.erase(output_shape.begin() + output_shape.size() - 2);
|
||||
}
|
||||
if (arg1_shape.size() == 1) {
|
||||
output_shape.erase(output_shape.begin() + output_shape.size() - 1);
|
||||
}
|
||||
output_shape = snippets::utils::get_planar_vdims(output_shape, m_io_layouts[2]);
|
||||
return {{output_shape}, snippets::ShapeInferStatus::success};
|
||||
}
|
||||
|
||||
} // namespace op
|
||||
} // namespace snippets
|
||||
} // namespace ov
|
||||
|
@ -3,6 +3,7 @@
|
||||
//
|
||||
#include "snippets/shape_inference/shape_infer_instances.hpp"
|
||||
#include "snippets/snippets_isa.hpp"
|
||||
#include "snippets/utils.hpp"
|
||||
#include "openvino/op/select.hpp"
|
||||
namespace ov {
|
||||
namespace snippets {
|
||||
@ -160,5 +161,76 @@ Result HorizonOpShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes
|
||||
return {{output_shapes}, ShapeInferStatus::success};
|
||||
}
|
||||
|
||||
// Captures the layouts that infer() needs: one entry per input port of `n`
// (in port order), then the layout of output port 0 as the final entry.
BrgemmShapeInfer::BrgemmShapeInfer(const std::shared_ptr<Node>& n) {
|
||||
for (const auto& in : n->inputs()) {
|
||||
const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(in);
|
||||
m_io_layouts.push_back(port->get_layout());
|
||||
}
|
||||
// The output layout is always stored last, so infer() can reach it via m_io_layouts.back()
// regardless of how many inputs the node has.
const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(n->output(0));
|
||||
m_io_layouts.push_back(port->get_layout());
|
||||
}
|
||||
|
||||
// Computes the Brgemm output shape from the input shapes.
//
// input_shapes: 2 or 3 shapes. Only input_shapes[0] and input_shapes[1] are read here;
//               a third shape, if present, does not affect the result.
// Returns: a single output shape with ShapeInferStatus::success.
// Throws: via OPENVINO_ASSERT on an unexpected number of inputs, and via OPENVINO_THROW
//         when two batch dimensions differ and neither is 1 nor DYNAMIC_DIMENSION.
Result BrgemmShapeInfer::infer(const std::vector<VectorDimsRef>& input_shapes) {
|
||||
OPENVINO_ASSERT(input_shapes.size() == 2 || input_shapes.size() == 3, "BRGEMM expects 2 or 3 input shapes for shape inference");
|
||||
|
||||
// Todo: Ideally we should use the layout stored in PortDescriptors. Can we do it?
|
||||
// Reorder both inputs with the layouts captured in the constructor before any rank handling.
const auto& arg0_shape = ov::snippets::utils::get_planar_vdims(input_shapes[0].get(), m_io_layouts[0]);
|
||||
const auto& arg1_shape = ov::snippets::utils::get_planar_vdims(input_shapes[1].get(), m_io_layouts[1]);
|
||||
|
||||
size_t arg0_rank = arg0_shape.size(), arg1_rank = arg1_shape.size();
|
||||
|
||||
// temporary shapes to calculate output shape
|
||||
VectorDims arg0_shape_tmp(arg0_shape), arg1_shape_tmp(arg1_shape);
|
||||
|
||||
// one-dimensional tensors unsqueezing is applied to each input independently.
|
||||
if (arg0_rank == 1) {
|
||||
// If the first input is 1D tensor, it is unsqueezed to 2D tensor (row vector)
|
||||
// by adding axes with size 1 at ROW_INDEX_DIM, to the left of the shape.
|
||||
// For example {S} will be reshaped to {1, S}.
|
||||
arg0_shape_tmp.insert(arg0_shape_tmp.begin(), 1);
|
||||
arg0_rank = arg0_shape_tmp.size();
|
||||
}
|
||||
if (arg1_rank == 1) {
|
||||
// If the second input is 1D tensor, it is unsqueezed to 2D tensor (column vector)
|
||||
// by adding axes with size 1 at COL_INDEX_DIM, to the right of the shape.
|
||||
// For example {S} will be reshaped to {S, 1}.
|
||||
arg1_shape_tmp.insert(arg1_shape_tmp.end(), 1);
|
||||
arg1_rank = arg1_shape_tmp.size();
|
||||
}
|
||||
|
||||
// add 1 to begin to align shape ranks if needed
|
||||
if (arg0_rank < arg1_rank)
|
||||
arg0_shape_tmp.insert(arg0_shape_tmp.begin(), arg1_rank - arg0_rank, 1);
|
||||
else if (arg0_rank > arg1_rank)
|
||||
arg1_shape_tmp.insert(arg1_shape_tmp.begin(), arg0_rank - arg1_rank, 1);
|
||||
|
||||
// Both temporary shapes have the same rank from here on.
size_t max_rank = arg0_shape_tmp.size();
|
||||
VectorDims output_shape(max_rank);
|
||||
// Broadcast the batch dimensions (everything except the last two).
// NOTE(review): max_rank is unsigned, so `max_rank - 2` wraps around if max_rank < 2.
// 1D inputs are unsqueezed to 2D above, so this can only happen when both inputs are
// rank 0 — confirm such shapes can never reach this function.
for (size_t i = 0; i < max_rank - 2; ++i) {
|
||||
if (arg0_shape_tmp[i] == arg1_shape_tmp[i]) {
|
||||
output_shape[i] = arg0_shape_tmp[i];
|
||||
} else {
|
||||
// A dimension equal to 1 or DYNAMIC_DIMENSION yields to the other input's dimension.
if (arg0_shape_tmp[i] == 1 || arg0_shape_tmp[i] == DYNAMIC_DIMENSION)
|
||||
output_shape[i] = arg1_shape_tmp[i];
|
||||
else if (arg1_shape_tmp[i] == 1 || arg1_shape_tmp[i] == DYNAMIC_DIMENSION)
|
||||
output_shape[i] = arg0_shape_tmp[i];
|
||||
else
|
||||
OPENVINO_THROW("Incompatible Brgemm batch dimension");
|
||||
}
|
||||
}
|
||||
// The last two output dimensions come from the row dimension of the first input
// and the column dimension of the second input.
output_shape[output_shape.size() - 2] = arg0_shape_tmp[arg0_shape_tmp.size() - 2]; // M
|
||||
output_shape[output_shape.size() - 1] = arg1_shape_tmp[arg1_shape_tmp.size() - 1]; // N
|
||||
|
||||
// removing the temporary axes from originally 1D tensors.
|
||||
// The checks use the original (pre-unsqueeze) shapes; the first erase shrinks output_shape,
// so the second erase then removes whatever is the last remaining axis.
if (arg0_shape.size() == 1) {
|
||||
output_shape.erase(output_shape.begin() + output_shape.size() - 2);
|
||||
}
|
||||
if (arg1_shape.size() == 1) {
|
||||
output_shape.erase(output_shape.begin() + output_shape.size() - 1);
|
||||
}
|
||||
// The output layout is the last entry of m_io_layouts (see the constructor), which holds
// for both the 2-input and the 3-input case.
output_shape = ov::snippets::utils::get_planar_vdims(output_shape, m_io_layouts.back());
|
||||
return {{output_shape}, snippets::ShapeInferStatus::success};
|
||||
}
|
||||
|
||||
} // namespace snippets
|
||||
} // namespace ov
|
||||
|
@ -58,11 +58,11 @@ const IShapeInferSnippetsFactory::TRegistry IShapeInferSnippetsFactory::registry
|
||||
SHAPE_INFER_PREDEFINED(op::Kernel, EmptyShapeInfer),
|
||||
SHAPE_INFER_PREDEFINED(op::Nop, EmptyShapeInfer),
|
||||
SHAPE_INFER_OP_SPECIFIC_EXTERNAL(opset1::Select, SelectShapeInfer),
|
||||
SHAPE_INFER_OP_SPECIFIC_EXTERNAL(op::Brgemm, BrgemmShapeInfer),
|
||||
// Note that Result has no output PortConnectors, so the shape must be empty
|
||||
SHAPE_INFER_PREDEFINED(ov::op::v0::Result, EmptyShapeInfer),
|
||||
//
|
||||
SHAPE_INFER_OP_SPECIFIC(op::LoadReshape),
|
||||
SHAPE_INFER_OP_SPECIFIC(op::Brgemm),
|
||||
SHAPE_INFER_OP_SPECIFIC(op::BroadcastLoad),
|
||||
SHAPE_INFER_OP_SPECIFIC(op::BroadcastMove),
|
||||
};
|
||||
|
@ -22,6 +22,7 @@
|
||||
#include "snippets/pass/matmul_to_brgemm.hpp"
|
||||
#include "utils/cpu_utils.hpp"
|
||||
#include "emitters/x64/cpu_generator.hpp"
|
||||
#include "transformations/snippets/x64/pass/lowered/set_brgemm_copy_b_buffers_shape.hpp"
|
||||
#include "transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.hpp"
|
||||
#include "transformations/snippets/x64/pass/lowered/brgemm_blocking.hpp"
|
||||
#include "transformations/snippets/x64/pass/mul_add_to_fma.hpp"
|
||||
@ -618,6 +619,7 @@ void Snippet::SnippetJitExecutor::generate(const jit_snippets_compile_args* jcp)
|
||||
|
||||
ov::snippets::lowered::pass::PassPipeline control_flow_pipeline;
|
||||
CPU_REGISTER_PASS_X64(control_flow_pipeline, ov::intel_cpu::pass::FuseLoadStoreConvert)
|
||||
CPU_REGISTER_PASS_X64(control_flow_pipeline, ov::intel_cpu::pass::SetBrgemmCopyBBuffersShape);
|
||||
// Note: we need to pass a valid shapeInfer factory to generate, so it can be used in the OptimizeDomain pass
|
||||
// in all other cases nGraph shape inference will be used until ticket # 113209 (PR 18563) is merged
|
||||
schedule = snippet_for_generation->generate(backend_passes,
|
||||
|
@ -57,39 +57,34 @@ void BrgemmCopyB::custom_constructor_validate_and_infer_types(std::vector<size_t
|
||||
// During the ctor call, BrgemmCopyB doesn't know its port descriptors.
|
||||
// So we use port descs from source inputs
|
||||
const auto element_type = get_input_element_type(0);
|
||||
const auto pshape = snippets::utils::get_planar_pshape(get_input_partial_shape(0), layout_input);
|
||||
validate(pshape, element_type);
|
||||
const auto& pshape = get_input_partial_shape(0);
|
||||
// The data is always stored in planar shape after repacking
|
||||
const auto planar_pshape = snippets::utils::get_planar_pshape(pshape, layout_input);
|
||||
// data repacking output
|
||||
set_output_type(0, element_type, planar_pshape);
|
||||
// If compensations are needed, they are provided in 2nd output (which is used in BrgemmCPU)
|
||||
if (is_with_compensations()) {
|
||||
set_output_type(1, ov::element::f32, planar_pshape);
|
||||
}
|
||||
validate(planar_pshape, element_type);
|
||||
}
|
||||
|
||||
void BrgemmCopyB::validate_and_infer_types() {
|
||||
INTERNAL_OP_SCOPE(BrgemmRepack_validate_and_infer_types);
|
||||
|
||||
const auto port = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(input(0));
|
||||
const auto shape = ov::Shape(port->get_shape());
|
||||
const auto& element_type = get_input_element_type(0);
|
||||
const auto& pshape = snippets::utils::get_planar_pshape(input(0));
|
||||
validate(pshape, element_type);
|
||||
const auto& planar_pshape = snippets::utils::get_planar_pshape(shape, port->get_layout());
|
||||
set_output_type(0, element_type, planar_pshape);
|
||||
if (is_with_compensations()) {
|
||||
set_output_type(1, ov::element::f32, planar_pshape);
|
||||
}
|
||||
validate(planar_pshape, element_type);
|
||||
}
|
||||
|
||||
void BrgemmCopyB::validate(const ov::PartialShape& pshape, const ov::element::Type& element_type) {
|
||||
NGRAPH_CHECK(one_of(element_type, element::bf16, element::i8),
|
||||
void BrgemmCopyB::validate(const ov::PartialShape& planar_pshape, const ov::element::Type& element_type) {
|
||||
OPENVINO_ASSERT(one_of(element_type, element::bf16, element::i8),
|
||||
"BrgemmCopyB doesn't support element type" + element_type.get_type_name());
|
||||
|
||||
if (pshape.is_dynamic()) {
|
||||
set_output_type(0, element_type, ov::PartialShape {ov::Dimension::dynamic()});
|
||||
if (is_with_compensations()) {
|
||||
set_output_type(1, ov::element::f32, ov::PartialShape {ov::Dimension::dynamic()});
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
const auto shape = pshape.get_shape();
|
||||
const auto N = *shape.rbegin();
|
||||
const auto K = *(shape.rbegin() + 1);
|
||||
|
||||
set_output_type(0, element_type, ov::PartialShape{ov::Dimension(rnd_up(K, m_brgemmVNNIFactor)),
|
||||
ov::Dimension(rnd_up(N, m_N_blk))});
|
||||
if (is_with_compensations()) {
|
||||
set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension(rnd_up(N, m_N_blk))});
|
||||
}
|
||||
}
|
||||
|
||||
void intel_cpu::BrgemmCopyB::compute_block_size_values(const size_t blk_size_k, const size_t blk_size_n) {
|
||||
@ -98,6 +93,17 @@ void intel_cpu::BrgemmCopyB::compute_block_size_values(const size_t blk_size_k,
|
||||
m_N_blk = blk_size_n != 0 ? blk_size_n : *input_shape.rbegin();
|
||||
}
|
||||
|
||||
// Returns the 2D shape {rnd_up(K, m_brgemmVNNIFactor), rnd_up(N, m_N_blk)} for the repacked data,
// where N is the last and K the second-to-last element of `planar_dims`.
// `planar_dims` must therefore contain at least two dimensions.
// NOTE(review): rnd_up is defined outside this view; presumably it rounds its first argument
// up to a multiple of the second — confirm.
ov::Shape intel_cpu::BrgemmCopyB::get_data_repacking_shape(const ov::snippets::VectorDims& planar_dims) const {
|
||||
const auto& N = *planar_dims.rbegin();
|
||||
const auto& K = *(planar_dims.rbegin() + 1);
|
||||
return ov::Shape{rnd_up(K, m_brgemmVNNIFactor), rnd_up(N, m_N_blk)};
|
||||
}
|
||||
|
||||
// Returns the 1D shape {rnd_up(N, m_N_blk)} for the compensations,
// where N is the last element of `planar_dims` (which must be non-empty).
// NOTE(review): rnd_up is defined outside this view; presumably it rounds its first argument
// up to a multiple of the second — confirm.
ov::Shape intel_cpu::BrgemmCopyB::get_compensation_shape(const ov::snippets::VectorDims& planar_dims) const {
|
||||
const auto& N = *planar_dims.rbegin();
|
||||
return ov::Shape{rnd_up(N, m_N_blk)};
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> intel_cpu::BrgemmCopyB::clone_with_new_inputs(const OutputVector& new_args) const {
|
||||
INTERNAL_OP_SCOPE(BrgemmRepack_clone_with_new_inputs);
|
||||
check_new_args_count(this, new_args);
|
||||
@ -120,29 +126,13 @@ BrgemmCopyB::ShapeInfer::ShapeInfer(const std::shared_ptr<ov::Node>& n) {
|
||||
OPENVINO_ASSERT(brg_copyb, "Got invalid node in BrgemmCopyB::ShapeInfer");
|
||||
m_layout = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(n->input(0))->get_layout();
|
||||
m_num_outs = brg_copyb->get_output_size();
|
||||
m_N_blk = brg_copyb->get_n_block_size();
|
||||
m_brgemmVNNIFactor = brg_copyb->m_brgemmVNNIFactor;
|
||||
}
|
||||
|
||||
snippets::IShapeInferSnippets::Result BrgemmCopyB::ShapeInfer::infer(const std::vector<snippets::VectorDimsRef>& input_shapes) {
|
||||
ov::snippets::IShapeInferSnippets::Result BrgemmCopyB::ShapeInfer::infer(const std::vector<ov::snippets::VectorDimsRef>& input_shapes) {
|
||||
OPENVINO_ASSERT(input_shapes.size() == 1, "Got unexpected number of input shapes");
|
||||
const auto& old_shape = input_shapes[0].get();
|
||||
snippets::VectorDims planar_shape;
|
||||
planar_shape.reserve(old_shape.size());
|
||||
for (const auto idx : m_layout)
|
||||
planar_shape.push_back(old_shape[idx]);
|
||||
const auto N = *planar_shape.rbegin();
|
||||
const auto K = *(planar_shape.rbegin() + 1);
|
||||
OPENVINO_ASSERT(N != DYNAMIC_DIMENSION && K != DYNAMIC_DIMENSION,
|
||||
"BrgemmCopyB shape infer got dynamic N or K dimension, which is not supported");
|
||||
|
||||
std::vector<snippets::VectorDims> new_shapes(m_num_outs);
|
||||
new_shapes[0].push_back(rnd_up(K, m_brgemmVNNIFactor));
|
||||
new_shapes[0].push_back(rnd_up(N, m_N_blk));
|
||||
if (m_num_outs == 2) {
|
||||
new_shapes[1].push_back(rnd_up(N, m_N_blk));
|
||||
}
|
||||
return {new_shapes, snippets::ShapeInferStatus::success};
|
||||
const auto planar_shape = ov::snippets::utils::get_planar_vdims(input_shapes[0].get(), m_layout);
|
||||
std::vector<ov::snippets::VectorDims> new_shapes(m_num_outs, planar_shape);
|
||||
return {new_shapes, ov::snippets::ShapeInferStatus::success};
|
||||
}
|
||||
|
||||
} // namespace intel_cpu
|
||||
|
@ -5,6 +5,7 @@
|
||||
#pragma once
|
||||
|
||||
#include "snippets/op/memory_access.hpp"
|
||||
#include "snippets/shape_types.hpp"
|
||||
#include <snippets/shape_inference/shape_inference.hpp>
|
||||
|
||||
namespace ov {
|
||||
@ -43,7 +44,11 @@ public:
|
||||
void set_k_block_size(size_t block_size) { m_K_blk = block_size; }
|
||||
void set_n_block_size(size_t block_size) { m_N_blk = block_size; }
|
||||
|
||||
ov::Shape get_data_repacking_shape(const ov::snippets::VectorDims& planar_dims) const;
|
||||
ov::Shape get_compensation_shape(const ov::snippets::VectorDims& planar_dims) const;
|
||||
|
||||
Type get_type() const { return m_type; }
|
||||
size_t get_brgemm_vnni_factor() const { return m_brgemmVNNIFactor; }
|
||||
element::Type get_src_element_type() const { return m_src_type; }
|
||||
bool is_with_compensations() const { return m_type == Type::WithCompensations; }
|
||||
|
||||
@ -55,8 +60,6 @@ public:
|
||||
class ShapeInfer : public snippets::IShapeInferSnippets {
|
||||
std::vector<size_t> m_layout{};
|
||||
size_t m_num_outs = 1;
|
||||
size_t m_N_blk = 64;
|
||||
size_t m_brgemmVNNIFactor = 1;
|
||||
public:
|
||||
explicit ShapeInfer(const std::shared_ptr<ov::Node>& n);
|
||||
Result infer(const std::vector<snippets::VectorDimsRef>& input_shapes) override;
|
||||
@ -64,7 +67,7 @@ public:
|
||||
|
||||
private:
|
||||
void custom_constructor_validate_and_infer_types(std::vector<size_t> layout_input = {});
|
||||
void validate(const ov::PartialShape& pshape, const ov::element::Type& element_type);
|
||||
void validate(const ov::PartialShape& planar_pshape, const ov::element::Type& element_type);
|
||||
void compute_block_size_values(const size_t blk_size_k, const size_t blk_size_n);
|
||||
|
||||
Type m_type = Type::OnlyRepacking;
|
||||
|
@ -114,21 +114,9 @@ void BrgemmCPU::validate_with_scratchpad(const ov::Shape& shape_b) const {
|
||||
// Additional check for 3rd input
|
||||
if (one_of(m_type, Type::WithCompensations, Type::AMX)) {
|
||||
const auto& pshape = get_input_partial_shape(2);
|
||||
NGRAPH_CHECK(pshape.is_static(), "BRGEMM Scratch must have static shape");
|
||||
const auto shape = pshape.to_shape();
|
||||
const auto type = get_input_element_type(2);
|
||||
OPENVINO_ASSERT(pshape.is_static(), "BRGEMM Scratch must have static shape");
|
||||
if (is_with_compensations()) {
|
||||
const auto expected_type = ov::element::f32;
|
||||
NGRAPH_CHECK(expected_type == type, "BRGEMM Scratch with compensations must have FP32 element type");
|
||||
const auto N = *shape_b.rbegin();
|
||||
// If N block size is not set, there is no meaning in validating the scratchpad shape
|
||||
if (m_N_blk != N) {
|
||||
const auto expected_shape = ov::Shape{rnd_up(N, m_N_blk)};
|
||||
NGRAPH_CHECK(expected_shape == shape, "BRGEMM Scratch with compensations must have shape {rnd_up(N, m_N_blk)}");
|
||||
}
|
||||
} else {
|
||||
NGRAPH_CHECK(ov::shape_size(shape) == SCRATCH_BYTE_SIZE && type == ov::element::u8,
|
||||
"BRGEMM Scratch for space workplace must be static, have U8 element type and size equal to " + std::to_string(SCRATCH_BYTE_SIZE));
|
||||
OPENVINO_ASSERT(get_input_element_type(2) == ov::element::f32, "BRGEMM Scratch with compensations must have FP32 element type");
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -181,13 +169,5 @@ size_t BrgemmCPU::get_offset_scratch() const {
|
||||
return get_input_offset(2);
|
||||
}
|
||||
|
||||
BrgemmCPU::ShapeInfer::ShapeInfer(const std::shared_ptr<ov::Node>& n) : Brgemm::ShapeInfer(n) {
|
||||
const auto& brg = ov::as_type_ptr<BrgemmCPU>(n);
|
||||
OPENVINO_ASSERT(brg, "Got invalid node in BrgemmCPU::ShapeInfer");
|
||||
const auto brgemm_copy = brg->is_with_data_repacking() ? brg->get_brgemm_copy() : nullptr;
|
||||
if (brgemm_copy)
|
||||
m_io_layouts[1] = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(brgemm_copy->input(0))->get_layout();
|
||||
}
|
||||
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
||||
|
@ -69,12 +69,6 @@ public:
|
||||
|
||||
constexpr static size_t SCRATCH_BYTE_SIZE = 32 * 1024;
|
||||
|
||||
class ShapeInfer : public Brgemm::ShapeInfer {
|
||||
public:
|
||||
explicit ShapeInfer(const std::shared_ptr<ov::Node>& n);
|
||||
};
|
||||
|
||||
|
||||
private:
|
||||
void custom_constructor_validate_and_infer_types(std::vector<size_t> layout_a, std::vector<size_t> layout_b, std::vector<size_t> layout_c);
|
||||
void compute_block_size_values(const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n);
|
||||
|
@ -0,0 +1,38 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "snippets/itt.hpp"
|
||||
|
||||
#include "set_brgemm_copy_b_buffers_shape.hpp"
|
||||
#include "snippets/snippets_isa.hpp"
|
||||
#include "snippets/utils.hpp"
|
||||
|
||||
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
|
||||
|
||||
// Walks the linear IR and, for every expression whose node is a BrgemmCopyB, sets the allocation
// shape of the Buffer consuming output 0 (and, with compensations, the Buffer consuming output 1)
// from the shapes BrgemmCopyB computes for its repacked data and compensations.
// Returns true if at least one BrgemmCopyB expression was processed, false otherwise.
bool ov::intel_cpu::pass::SetBrgemmCopyBBuffersShape::run(snippets::lowered::LinearIR& linear_ir) {
|
||||
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::SetBrgemmCopyBBuffersShape")
|
||||
|
||||
// Returns the Buffer node that consumes output `out_idx` of `expr`.
// Asserts that the output has exactly one consumer and that this consumer is a Buffer.
auto get_buffer_from_output = [](const snippets::lowered::ExpressionPtr& expr, const size_t out_idx) {
|
||||
const auto& consumers = expr->get_output_port_connector(out_idx)->get_consumers();
|
||||
OPENVINO_ASSERT(consumers.size() == 1, "BrgemmCopyB must have only 1 consumer");
|
||||
const auto buffer = ov::as_type_ptr<ov::snippets::op::Buffer>(consumers.begin()->get_expr()->get_node());
|
||||
OPENVINO_ASSERT(buffer, "BrgemmCopyB consumer must be Buffer");
|
||||
return buffer;
|
||||
};
|
||||
|
||||
bool modified = false;
|
||||
for (const auto& expr : linear_ir) {
|
||||
if (auto copy_b = ov::as_type_ptr<ov::intel_cpu::BrgemmCopyB>(expr->get_node())) {
|
||||
const auto buffer = get_buffer_from_output(expr, 0);
|
||||
// The shape stored in the descriptor of output 0 is the input to both shape helpers below.
const auto& out_desc = expr->get_output_port_descriptor(0);
|
||||
buffer->set_allocation_shape(copy_b->get_data_repacking_shape(out_desc->get_shape()));
|
||||
if (copy_b->is_with_compensations()) {
|
||||
const auto compensations_buffer = get_buffer_from_output(expr, 1);
|
||||
// NOTE(review): the compensation shape is derived from output 0's descriptor, not output 1's.
// BrgemmCopyB::ShapeInfer assigns the same planar shape to every output, so this looks
// intentional — confirm.
compensations_buffer->set_allocation_shape(copy_b->get_compensation_shape(out_desc->get_shape()));
|
||||
}
|
||||
modified = true;
|
||||
}
|
||||
}
|
||||
return modified;
|
||||
}
|
@ -0,0 +1,27 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "snippets/lowered/pass/pass.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
namespace pass {
|
||||
|
||||
/**
|
||||
* @interface SetBrgemmCopyBBuffersShape
|
||||
* @brief Sets the allocation shape for the Buffers after BrgemmCopyB node using BrgemmCopyB parameters
|
||||
* @ingroup snippets
|
||||
*/
|
||||
// Lowered pass operating on a LinearIR; it has no state of its own.
class SetBrgemmCopyBBuffersShape: public snippets::lowered::pass::Pass {
|
||||
public:
|
||||
SetBrgemmCopyBBuffersShape() = default;
|
||||
OPENVINO_RTTI("SetBrgemmCopyBBuffersShape", "Pass");
|
||||
// Updates the allocation shapes of the Buffers that follow BrgemmCopyB expressions;
// returns true if any BrgemmCopyB expression was found in `linear_ir`.
bool run(snippets::lowered::LinearIR& linear_ir) override;
|
||||
};
|
||||
|
||||
} // namespace pass
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
@ -22,18 +22,6 @@
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
using namespace snippets::lowered;
|
||||
namespace {
|
||||
template <typename T>
|
||||
void change_desc_shape(const T& port) {
|
||||
const auto desc = PortDescriptorUtils::get_port_descriptor_ptr(port);
|
||||
const auto& shape = port.get_shape();
|
||||
if (desc->get_shape() != shape) {
|
||||
desc->set_shape(shape);
|
||||
}
|
||||
}
|
||||
} // namespace
|
||||
|
||||
pass::SetBrgemmCPUBlockingParams::SetBrgemmCPUBlockingParams() {
|
||||
MATCHER_SCOPE(SetBrgemmCPUBlockingParams);
|
||||
|
||||
@ -73,7 +61,7 @@ pass::SetBrgemmCPUBlockingParams::SetBrgemmCPUBlockingParams() {
|
||||
|
||||
const bool isAMXSupported = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx);
|
||||
const auto precision = brgemm_copy_b->get_src_element_type();
|
||||
const auto brgemmVNNIFactor = 4 / precision.size();
|
||||
const auto brgemmVNNIFactor = brgemm_copy_b->get_brgemm_vnni_factor();
|
||||
const bool use_amx = isAMXSupported && precision != ov::element::f32 && (K % brgemmVNNIFactor == 0) && (N % brgemmVNNIFactor == 0);
|
||||
|
||||
const size_t copy_b_block_size_k = use_amx ? brgemm_block_size_k : K;
|
||||
@ -81,18 +69,8 @@ pass::SetBrgemmCPUBlockingParams::SetBrgemmCPUBlockingParams() {
|
||||
|
||||
brgemm_copy_b->set_k_block_size(copy_b_block_size_k);
|
||||
brgemm_copy_b->set_n_block_size(copy_b_block_size_n);
|
||||
// since N block size affects output shapes, the validation must be called explicitly right after the block size changing
|
||||
brgemm_copy_b->validate_and_infer_types();
|
||||
change_desc_shape(brgemm_copy_b->output(0));
|
||||
if (brgemm_copy_b->is_with_compensations())
|
||||
change_desc_shape(brgemm_copy_b->output(1));
|
||||
}
|
||||
|
||||
brgemm->validate_and_infer_types();
|
||||
change_desc_shape(brgemm->input(1));
|
||||
if (brgemm->is_with_scratchpad())
|
||||
change_desc_shape(brgemm->input(2));
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
|
@ -28,6 +28,8 @@ ShapeInferPtr CPUShapeInferSnippetsFactory::get_specific_op_shape_infer(const ov
|
||||
{ OP::get_type_info_static(), [](const std::shared_ptr<ov::Node>& n) { return std::make_shared<InferType>();} }
|
||||
#define SHAPE_INFER_OP_SPECIFIC(OP) \
|
||||
{ OP::get_type_info_static(), [](const std::shared_ptr<ov::Node>& n) { return std::make_shared<OP::ShapeInfer>(n);} }
|
||||
#define SHAPE_INFER_OP_SPECIFIC_EXTERNAL(OP, InferType) \
|
||||
{ OP::get_type_info_static(), [](const std::shared_ptr<ov::Node>& n) { return std::make_shared<InferType>(n);} }
|
||||
|
||||
const CPUShapeInferSnippetsFactory::TRegistry CPUShapeInferSnippetsFactory::specific_ops_registry {
|
||||
SHAPE_INFER_PREDEFINED(ov::intel_cpu::FusedMulAdd, NumpyBroadcastShapeInfer),
|
||||
@ -36,9 +38,9 @@ const CPUShapeInferSnippetsFactory::TRegistry CPUShapeInferSnippetsFactory::spec
|
||||
SHAPE_INFER_PREDEFINED(ov::intel_cpu::LoadConvertTruncation, PassThroughShapeInfer),
|
||||
SHAPE_INFER_PREDEFINED(ov::intel_cpu::StoreConvertSaturation, PassThroughShapeInfer),
|
||||
SHAPE_INFER_PREDEFINED(ov::intel_cpu::StoreConvertTruncation, PassThroughShapeInfer),
|
||||
SHAPE_INFER_OP_SPECIFIC_EXTERNAL(ov::intel_cpu::BrgemmCPU, BrgemmShapeInfer),
|
||||
//
|
||||
SHAPE_INFER_OP_SPECIFIC(ov::intel_cpu::BrgemmCopyB),
|
||||
SHAPE_INFER_OP_SPECIFIC(ov::intel_cpu::BrgemmCPU),
|
||||
};
|
||||
#undef SHAPE_INFER_OP_SPECIFIC
|
||||
#undef SHAPE_INFER_PREDEFINED
|
||||
|
Loading…
Reference in New Issue
Block a user