From 377e92714905922de0b20979458425b0282c2dd0 Mon Sep 17 00:00:00 2001
From: Vladislav Golubev
Date: Thu, 12 Oct 2023 14:34:53 +0200
Subject: [PATCH] [Snippets] Changed BrgemmCopyB shape inference (#19957)

---
 .../snippets/include/snippets/op/brgemm.hpp    |  8 --
 .../snippets/include/snippets/op/buffer.hpp    |  3 +-
 .../shape_inference/shape_infer_instances.hpp  |  7 ++
 src/common/snippets/src/op/brgemm.cpp          | 71 ----------------
 .../shape_inference/shape_infer_instances.cpp  | 72 ++++++++++++++++
 .../src/shape_inference/shape_inference.cpp    |  2 +-
 src/plugins/intel_cpu/src/nodes/subgraph.cpp   |  2 +
 .../snippets/x64/op/brgemm_copy_b.cpp          | 82 ++++++++-----------
 .../snippets/x64/op/brgemm_copy_b.hpp          |  9 +-
 .../snippets/x64/op/brgemm_cpu.cpp             | 24 +-----
 .../snippets/x64/op/brgemm_cpu.hpp             |  6 --
 .../set_brgemm_copy_b_buffers_shape.cpp        | 38 +++++++++
 .../set_brgemm_copy_b_buffers_shape.hpp        | 27 ++++++
 .../pass/set_brgemm_cpu_blocking_params.cpp    | 24 +-----
 .../snippets/x64/shape_inference.cpp           |  4 +-
 15 files changed, 197 insertions(+), 182 deletions(-)
 create mode 100644 src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/set_brgemm_copy_b_buffers_shape.cpp
 create mode 100644 src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/set_brgemm_copy_b_buffers_shape.hpp

diff --git a/src/common/snippets/include/snippets/op/brgemm.hpp b/src/common/snippets/include/snippets/op/brgemm.hpp
index 50cca60bbbc..8ba681fb8e9 100644
--- a/src/common/snippets/include/snippets/op/brgemm.hpp
+++ b/src/common/snippets/include/snippets/op/brgemm.hpp
@@ -39,14 +39,6 @@ public:
 
     bool has_evaluate() const override { return false; }
 
-    class ShapeInfer : public IShapeInferSnippets {
-    protected:
-        std::vector> m_io_layouts;
-    public:
-        explicit ShapeInfer(const std::shared_ptr& n);
-        Result infer(const std::vector& input_shapes) override;
-    };
-
 protected:
     ov::element::Type get_output_type() const;
     std::vector get_planar_input_shapes(const std::vector>& inputs) const;
diff --git a/src/common/snippets/include/snippets/op/buffer.hpp b/src/common/snippets/include/snippets/op/buffer.hpp
index 7a644644dd7..9f522ed3d45 100644
--- a/src/common/snippets/include/snippets/op/buffer.hpp
+++ b/src/common/snippets/include/snippets/op/buffer.hpp
@@ -44,9 +44,10 @@ public:
 
     size_t get_id() const { return m_id; }
     Type get_type() const { return m_type; }
-    ov::Shape get_allocation_shape() const { return m_shape; }
     int64_t get_offset() const { return m_offset; }
     void set_id(size_t id) { m_id = id; }
+    const ov::Shape& get_allocation_shape() const { return m_shape; }
+    void set_allocation_shape(const ov::Shape& allocation_shape) { m_shape = allocation_shape; }
     void set_offset(int64_t offset) { m_offset = offset; }
 
     size_t get_byte_size() const;
diff --git a/src/common/snippets/include/snippets/shape_inference/shape_infer_instances.hpp b/src/common/snippets/include/snippets/shape_inference/shape_infer_instances.hpp
index f673f8ff997..af69ad90511 100644
--- a/src/common/snippets/include/snippets/shape_inference/shape_infer_instances.hpp
+++ b/src/common/snippets/include/snippets/shape_inference/shape_infer_instances.hpp
@@ -61,5 +61,12 @@ public:
     Result infer(const std::vector& input_shapes) override;
 };
 
+class BrgemmShapeInfer : public IShapeInferSnippets {
+    std::vector> m_io_layouts;
+public:
+    explicit BrgemmShapeInfer(const std::shared_ptr& n);
+    Result infer(const std::vector& input_shapes) override;
+};
+
 } // namespace snippets
 } // namespace ov
diff --git a/src/common/snippets/src/op/brgemm.cpp b/src/common/snippets/src/op/brgemm.cpp
index 1f415a4f64b..b64a4328a83 100644
--- a/src/common/snippets/src/op/brgemm.cpp
+++ b/src/common/snippets/src/op/brgemm.cpp
@@ -188,77 +188,6 @@ ov::PartialShape Brgemm::get_output_partial_shape(const std::vector& n) {
-    for (const auto& in : n->inputs()) {
-        const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(in);
-        m_io_layouts.push_back(port->get_layout());
-    }
-    m_io_layouts.push_back(get_output_layout(n));
-}
-
-IShapeInferSnippets::Result Brgemm::ShapeInfer::infer(const std::vector& input_shapes) {
-    OPENVINO_ASSERT(input_shapes.size() == 2, "BRGEMM expects 2 input shapes for shape inference");
-
-    // Todo: Ideally we should use the layout stored in PortDescriptors. Can we do it?
-    const auto& arg0_shape = snippets::utils::get_planar_vdims(input_shapes[0].get(), m_io_layouts[0]);
-    const auto& arg1_shape = snippets::utils::get_planar_vdims(input_shapes[1].get(), m_io_layouts[1]);
-
-    size_t arg0_rank = arg0_shape.size(), arg1_rank = arg1_shape.size();
-
-    // temporary shapes to calculate output shape
-    VectorDims arg0_shape_tmp(arg0_shape), arg1_shape_tmp(arg1_shape);
-
-    // one-dimensional tensors unsqueezing is applied to each input independently.
-    if (arg0_rank == 1) {
-        // If the first input is 1D tensor, it is unsqueezed to 2D tensor (row vector)
-        // by adding axes with size 1 at ROW_INDEX_DIM, to the left of the shape.
-        // For example {S} will be reshaped to {1, S}.
-        arg0_shape_tmp.insert(arg0_shape_tmp.begin(), 1);
-        arg0_rank = arg0_shape_tmp.size();
-    }
-    if (arg1_rank == 1) {
-        // If the second input is 1D tensor, it is unsqueezed to 2D tensor (column vector)
-        // by adding axes with size 1 at COL_INDEX_DIM, to the right of the shape.
-        // For example {S} will be reshaped to {S, 1}.
-        arg1_shape_tmp.insert(arg1_shape_tmp.end(), 1);
-        arg1_rank = arg1_shape_tmp.size();
-    }
-
-    // add 1 to begin to align shape ranks if needed
-    if (arg0_rank < arg1_rank)
-        arg0_shape_tmp.insert(arg0_shape_tmp.begin(), arg1_rank - arg0_rank, 1);
-    else if (arg0_rank > arg1_rank)
-        arg1_shape_tmp.insert(arg1_shape_tmp.begin(), arg0_rank - arg1_rank, 1);
-
-    size_t max_rank = arg0_shape_tmp.size();
-    VectorDims output_shape(max_rank);
-    for (size_t i = 0; i < max_rank - 2; ++i) {
-        if (arg0_shape_tmp[i] == arg1_shape_tmp[i]) {
-            output_shape[i] = arg0_shape_tmp[i];
-        } else {
-            if (arg0_shape_tmp[i] == 1 || arg0_shape_tmp[i] == DYNAMIC_DIMENSION)
-                output_shape[i] = arg1_shape_tmp[i];
-            else if (arg1_shape_tmp[i] == 1 || arg1_shape_tmp[i] == DYNAMIC_DIMENSION)
-                output_shape[i] = arg0_shape_tmp[i];
-            else
-                OPENVINO_THROW("Incompatible Brgemm batch dimension");
-        }
-    }
-    output_shape[output_shape.size() - 2] = arg0_shape_tmp[arg0_shape_tmp.size() - 2]; // M
-    output_shape[output_shape.size() - 1] = arg1_shape_tmp[arg1_shape_tmp.size() - 1]; // N
-
-    // removing the temporary axes from originally 1D tensors.
-    if (arg0_shape.size() == 1) {
-        output_shape.erase(output_shape.begin() + output_shape.size() - 2);
-    }
-    if (arg1_shape.size() == 1) {
-        output_shape.erase(output_shape.begin() + output_shape.size() - 1);
-    }
-    output_shape = snippets::utils::get_planar_vdims(output_shape, m_io_layouts[2]);
-    return {{output_shape}, snippets::ShapeInferStatus::success};
-}
-
 } // namespace op
 } // namespace snippets
 } // namespace ov
diff --git a/src/common/snippets/src/shape_inference/shape_infer_instances.cpp b/src/common/snippets/src/shape_inference/shape_infer_instances.cpp
index e8df307a0b9..61404d208fd 100644
--- a/src/common/snippets/src/shape_inference/shape_infer_instances.cpp
+++ b/src/common/snippets/src/shape_inference/shape_infer_instances.cpp
@@ -3,6 +3,7 @@
 //
 #include "snippets/shape_inference/shape_infer_instances.hpp"
 #include "snippets/snippets_isa.hpp"
+#include "snippets/utils.hpp"
 #include "openvino/op/select.hpp"
 namespace ov {
 namespace snippets {
@@ -160,5 +161,76 @@ Result HorizonOpShapeInfer::infer(const std::vector& input_shapes
     return {{output_shapes}, ShapeInferStatus::success};
 }
 
+BrgemmShapeInfer::BrgemmShapeInfer(const std::shared_ptr& n) {
+    for (const auto& in : n->inputs()) {
+        const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(in);
+        m_io_layouts.push_back(port->get_layout());
+    }
+    const auto& port = lowered::PortDescriptorUtils::get_port_descriptor_ptr(n->output(0));
+    m_io_layouts.push_back(port->get_layout());
+}
+
+Result BrgemmShapeInfer::infer(const std::vector& input_shapes) {
+    OPENVINO_ASSERT(input_shapes.size() == 2 || input_shapes.size() == 3, "BRGEMM expects 2 or 3 input shapes for shape inference");
+
+    // Todo: Ideally we should use the layout stored in PortDescriptors. Can we do it?
+    const auto& arg0_shape = ov::snippets::utils::get_planar_vdims(input_shapes[0].get(), m_io_layouts[0]);
+    const auto& arg1_shape = ov::snippets::utils::get_planar_vdims(input_shapes[1].get(), m_io_layouts[1]);
+
+    size_t arg0_rank = arg0_shape.size(), arg1_rank = arg1_shape.size();
+
+    // temporary shapes to calculate output shape
+    VectorDims arg0_shape_tmp(arg0_shape), arg1_shape_tmp(arg1_shape);
+
+    // one-dimensional tensors unsqueezing is applied to each input independently.
+    if (arg0_rank == 1) {
+        // If the first input is 1D tensor, it is unsqueezed to 2D tensor (row vector)
+        // by adding axes with size 1 at ROW_INDEX_DIM, to the left of the shape.
+        // For example {S} will be reshaped to {1, S}.
+        arg0_shape_tmp.insert(arg0_shape_tmp.begin(), 1);
+        arg0_rank = arg0_shape_tmp.size();
+    }
+    if (arg1_rank == 1) {
+        // If the second input is 1D tensor, it is unsqueezed to 2D tensor (column vector)
+        // by adding axes with size 1 at COL_INDEX_DIM, to the right of the shape.
+        // For example {S} will be reshaped to {S, 1}.
+        arg1_shape_tmp.insert(arg1_shape_tmp.end(), 1);
+        arg1_rank = arg1_shape_tmp.size();
+    }
+
+    // add 1 to begin to align shape ranks if needed
+    if (arg0_rank < arg1_rank)
+        arg0_shape_tmp.insert(arg0_shape_tmp.begin(), arg1_rank - arg0_rank, 1);
+    else if (arg0_rank > arg1_rank)
+        arg1_shape_tmp.insert(arg1_shape_tmp.begin(), arg0_rank - arg1_rank, 1);
+
+    size_t max_rank = arg0_shape_tmp.size();
+    VectorDims output_shape(max_rank);
+    for (size_t i = 0; i < max_rank - 2; ++i) {
+        if (arg0_shape_tmp[i] == arg1_shape_tmp[i]) {
+            output_shape[i] = arg0_shape_tmp[i];
+        } else {
+            if (arg0_shape_tmp[i] == 1 || arg0_shape_tmp[i] == DYNAMIC_DIMENSION)
+                output_shape[i] = arg1_shape_tmp[i];
+            else if (arg1_shape_tmp[i] == 1 || arg1_shape_tmp[i] == DYNAMIC_DIMENSION)
+                output_shape[i] = arg0_shape_tmp[i];
+            else
+                OPENVINO_THROW("Incompatible Brgemm batch dimension");
+        }
+    }
+    output_shape[output_shape.size() - 2] = arg0_shape_tmp[arg0_shape_tmp.size() - 2]; // M
+    output_shape[output_shape.size() - 1] = arg1_shape_tmp[arg1_shape_tmp.size() - 1]; // N
+
+    // removing the temporary axes from originally 1D tensors.
+    if (arg0_shape.size() == 1) {
+        output_shape.erase(output_shape.begin() + output_shape.size() - 2);
+    }
+    if (arg1_shape.size() == 1) {
+        output_shape.erase(output_shape.begin() + output_shape.size() - 1);
+    }
+    output_shape = ov::snippets::utils::get_planar_vdims(output_shape, m_io_layouts.back());
+    return {{output_shape}, snippets::ShapeInferStatus::success};
+}
+
 } // namespace snippets
 } // namespace ov
diff --git a/src/common/snippets/src/shape_inference/shape_inference.cpp b/src/common/snippets/src/shape_inference/shape_inference.cpp
index cfc4dc460d4..22470a13d34 100644
--- a/src/common/snippets/src/shape_inference/shape_inference.cpp
+++ b/src/common/snippets/src/shape_inference/shape_inference.cpp
@@ -58,11 +58,11 @@ const IShapeInferSnippetsFactory::TRegistry IShapeInferSnippetsFactory::registry
     SHAPE_INFER_PREDEFINED(op::Kernel, EmptyShapeInfer),
     SHAPE_INFER_PREDEFINED(op::Nop, EmptyShapeInfer),
     SHAPE_INFER_OP_SPECIFIC_EXTERNAL(opset1::Select, SelectShapeInfer),
+    SHAPE_INFER_OP_SPECIFIC_EXTERNAL(op::Brgemm, BrgemmShapeInfer),
     // Note that Result has no output PortConnectors, so the shape must be empty
     SHAPE_INFER_PREDEFINED(ov::op::v0::Result, EmptyShapeInfer),
     //
     SHAPE_INFER_OP_SPECIFIC(op::LoadReshape),
-    SHAPE_INFER_OP_SPECIFIC(op::Brgemm),
     SHAPE_INFER_OP_SPECIFIC(op::BroadcastLoad),
     SHAPE_INFER_OP_SPECIFIC(op::BroadcastMove),
 };
diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp
index dd2c756ba63..c20ecbea76c 100644
--- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp
+++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp
@@ -22,6 +22,7 @@
 #include "snippets/pass/matmul_to_brgemm.hpp"
 #include "utils/cpu_utils.hpp"
 #include "emitters/x64/cpu_generator.hpp"
+#include "transformations/snippets/x64/pass/lowered/set_brgemm_copy_b_buffers_shape.hpp"
 #include "transformations/snippets/x64/pass/lowered/fuse_load_store_and_convert.hpp"
 #include "transformations/snippets/x64/pass/lowered/brgemm_blocking.hpp"
 #include "transformations/snippets/x64/pass/mul_add_to_fma.hpp"
@@ -618,6 +619,7 @@ void Snippet::SnippetJitExecutor::generate(const jit_snippets_compile_args* jcp)
     ov::snippets::lowered::pass::PassPipeline control_flow_pipeline;
     CPU_REGISTER_PASS_X64(control_flow_pipeline, ov::intel_cpu::pass::FuseLoadStoreConvert)
+    CPU_REGISTER_PASS_X64(control_flow_pipeline, ov::intel_cpu::pass::SetBrgemmCopyBBuffersShape);
     // Note: we need to pass valid shapeInfer factory to generate, so it can be used in OptimizeDomain pass
     // in all other cases nGraph shape inference will be used until ticket # 113209 (PR 18563) is merged
     schedule = snippet_for_generation->generate(backend_passes,
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp
index e16088a1567..643b5d74fc9 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.cpp
@@ -57,39 +57,34 @@ void BrgemmCopyB::custom_constructor_validate_and_infer_types(std::vectorget_shape());
     const auto& element_type = get_input_element_type(0);
-    const auto& pshape = snippets::utils::get_planar_pshape(input(0));
-    validate(pshape, element_type);
+    const auto& planar_pshape = snippets::utils::get_planar_pshape(shape, port->get_layout());
+    set_output_type(0, element_type, planar_pshape);
+    if (is_with_compensations()) {
+        set_output_type(1, ov::element::f32, planar_pshape);
+    }
+    validate(planar_pshape, element_type);
 }
 
-void BrgemmCopyB::validate(const ov::PartialShape& pshape, const ov::element::Type& element_type) {
-    NGRAPH_CHECK(one_of(element_type, element::bf16, element::i8),
-        "BrgemmCopyB doesn't support element type" + element_type.get_type_name());
-
-    if (pshape.is_dynamic()) {
-        set_output_type(0, element_type, ov::PartialShape {ov::Dimension::dynamic()});
-        if (is_with_compensations()) {
-            set_output_type(1, ov::element::f32, ov::PartialShape {ov::Dimension::dynamic()});
-        }
-        return;
-    }
-
-    const auto shape = pshape.get_shape();
-    const auto N = *shape.rbegin();
-    const auto K = *(shape.rbegin() + 1);
-
-    set_output_type(0, element_type, ov::PartialShape{ov::Dimension(rnd_up(K, m_brgemmVNNIFactor)),
-                                                      ov::Dimension(rnd_up(N, m_N_blk))});
-    if (is_with_compensations()) {
-        set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension(rnd_up(N, m_N_blk))});
-    }
+void BrgemmCopyB::validate(const ov::PartialShape& planar_pshape, const ov::element::Type& element_type) {
+    OPENVINO_ASSERT(one_of(element_type, element::bf16, element::i8),
+        "BrgemmCopyB doesn't support element type" + element_type.get_type_name());
 }
 
 void intel_cpu::BrgemmCopyB::compute_block_size_values(const size_t blk_size_k,
@@ -98,6 +93,17 @@ void intel_cpu::BrgemmCopyB::compute_block_size_values(const size_t blk_size_k,
     m_N_blk = blk_size_n != 0 ? blk_size_n : *input_shape.rbegin();
 }
 
+ov::Shape intel_cpu::BrgemmCopyB::get_data_repacking_shape(const ov::snippets::VectorDims& planar_dims) const {
+    const auto& N = *planar_dims.rbegin();
+    const auto& K = *(planar_dims.rbegin() + 1);
+    return ov::Shape{rnd_up(K, m_brgemmVNNIFactor), rnd_up(N, m_N_blk)};
+}
+
+ov::Shape intel_cpu::BrgemmCopyB::get_compensation_shape(const ov::snippets::VectorDims& planar_dims) const {
+    const auto& N = *planar_dims.rbegin();
+    return ov::Shape{rnd_up(N, m_N_blk)};
+}
+
 std::shared_ptr intel_cpu::BrgemmCopyB::clone_with_new_inputs(const OutputVector& new_args) const {
     INTERNAL_OP_SCOPE(BrgemmRepack_clone_with_new_inputs);
     check_new_args_count(this, new_args);
@@ -120,29 +126,13 @@ BrgemmCopyB::ShapeInfer::ShapeInfer(const std::shared_ptr& n) {
     OPENVINO_ASSERT(brg_copyb, "Got invalid node in BrgemmCopyB::ShapeInfer");
     m_layout = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(n->input(0))->get_layout();
     m_num_outs = brg_copyb->get_output_size();
-    m_N_blk = brg_copyb->get_n_block_size();
-    m_brgemmVNNIFactor = brg_copyb->m_brgemmVNNIFactor;
 }
 
-snippets::IShapeInferSnippets::Result BrgemmCopyB::ShapeInfer::infer(const std::vector& input_shapes) {
+ov::snippets::IShapeInferSnippets::Result BrgemmCopyB::ShapeInfer::infer(const std::vector& input_shapes) {
     OPENVINO_ASSERT(input_shapes.size() == 1, "Got unexpected number of input shapes");
-    const auto& old_shape = input_shapes[0].get();
-    snippets::VectorDims planar_shape;
-    planar_shape.reserve(old_shape.size());
-    for (const auto idx : m_layout)
-        planar_shape.push_back(old_shape[idx]);
-    const auto N = *planar_shape.rbegin();
-    const auto K = *(planar_shape.rbegin() + 1);
-    OPENVINO_ASSERT(N != DYNAMIC_DIMENSION && K != DYNAMIC_DIMENSION,
-        "BrgemmCopyB shape infer got dynamic N or K dimension, which is not supported");
-
-    std::vector new_shapes(m_num_outs);
-    new_shapes[0].push_back(rnd_up(K, m_brgemmVNNIFactor));
-    new_shapes[0].push_back(rnd_up(N, m_N_blk));
-    if (m_num_outs == 2) {
-        new_shapes[1].push_back(rnd_up(N, m_N_blk));
-    }
-    return {new_shapes, snippets::ShapeInferStatus::success};
+    const auto planar_shape = ov::snippets::utils::get_planar_vdims(input_shapes[0].get(), m_layout);
+    std::vector new_shapes(m_num_outs, planar_shape);
+    return {new_shapes, ov::snippets::ShapeInferStatus::success};
 }
 
 } // namespace intel_cpu
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp
index 62703049aea..9274ad026e5 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_copy_b.hpp
@@ -5,6 +5,7 @@
 #pragma once
 
 #include "snippets/op/memory_access.hpp"
+#include "snippets/shape_types.hpp"
 #include
 
 namespace ov {
 namespace intel_cpu {
@@ -43,7 +44,11 @@ public:
     void set_k_block_size(size_t block_size) { m_K_blk = block_size; }
     void set_n_block_size(size_t block_size) { m_N_blk = block_size; }
 
+    ov::Shape get_data_repacking_shape(const ov::snippets::VectorDims& planar_dims) const;
+    ov::Shape get_compensation_shape(const ov::snippets::VectorDims& planar_dims) const;
+
     Type get_type() const { return m_type; }
+    size_t get_brgemm_vnni_factor() const { return m_brgemmVNNIFactor; }
     element::Type get_src_element_type() const { return m_src_type; }
     bool is_with_compensations() const { return m_type == Type::WithCompensations; }
@@ -55,8 +60,6 @@ public:
     class ShapeInfer : public snippets::IShapeInferSnippets {
         std::vector m_layout{};
         size_t m_num_outs = 1;
-        size_t m_N_blk = 64;
-        size_t m_brgemmVNNIFactor = 1;
     public:
         explicit ShapeInfer(const std::shared_ptr& n);
         Result infer(const std::vector& input_shapes) override;
@@ -64,7 +67,7 @@ public:
 private:
     void custom_constructor_validate_and_infer_types(std::vector layout_input = {});
-    void validate(const ov::PartialShape& pshape, const ov::element::Type& element_type);
+    void validate(const ov::PartialShape& planar_pshape, const ov::element::Type& element_type);
     void compute_block_size_values(const size_t blk_size_k, const size_t blk_size_n);
 
     Type m_type = Type::OnlyRepacking;
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp
index 03e3325376c..20f7fccafe3 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.cpp
@@ -114,21 +114,9 @@ void BrgemmCPU::validate_with_scratchpad(const ov::Shape& shape_b) const {
     // Additional check for 3rd input
     if (one_of(m_type, Type::WithCompensations, Type::AMX)) {
         const auto& pshape = get_input_partial_shape(2);
-        NGRAPH_CHECK(pshape.is_static(), "BRGEMM Scratch must have static shape");
-        const auto shape = pshape.to_shape();
-        const auto type = get_input_element_type(2);
+        OPENVINO_ASSERT(pshape.is_static(), "BRGEMM Scratch must have static shape");
         if (is_with_compensations()) {
-            const auto expected_type = ov::element::f32;
-            NGRAPH_CHECK(expected_type == type, "BRGEMM Scratch with compensations must have FP32 element type");
-            const auto N = *shape_b.rbegin();
-            // If N block size is not set, there is no meaning in validating the scratchpad shape
-            if (m_N_blk != N) {
-                const auto expected_shape = ov::Shape{rnd_up(N, m_N_blk)};
-                NGRAPH_CHECK(expected_shape == shape, "BRGEMM Scratch with compensations must have shape {rnd_up(N, m_N_blk)}");
-            }
-        } else {
-            NGRAPH_CHECK(ov::shape_size(shape) == SCRATCH_BYTE_SIZE && type == ov::element::u8,
-                "BRGEMM Scratch for space workplace must be static, have U8 element type and size equal to " + std::to_string(SCRATCH_BYTE_SIZE));
+            OPENVINO_ASSERT(get_input_element_type(2) == ov::element::f32, "BRGEMM Scratch with compensations must have FP32 element type");
         }
     }
 }
@@ -181,13 +169,5 @@ size_t BrgemmCPU::get_offset_scratch() const {
     return get_input_offset(2);
 }
 
-BrgemmCPU::ShapeInfer::ShapeInfer(const std::shared_ptr& n) : Brgemm::ShapeInfer(n) {
-    const auto& brg = ov::as_type_ptr(n);
-    OPENVINO_ASSERT(brg, "Got invalid node in BrgemmCPU::ShapeInfer");
-    const auto brgemm_copy = brg->is_with_data_repacking() ? brg->get_brgemm_copy() : nullptr;
-    if (brgemm_copy)
-        m_io_layouts[1] = snippets::lowered::PortDescriptorUtils::get_port_descriptor_ptr(brgemm_copy->input(0))->get_layout();
-}
-
 } // namespace intel_cpu
 } // namespace ov
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.hpp
index e1957bb66d2..bf07b7a8546 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.hpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/op/brgemm_cpu.hpp
@@ -69,12 +69,6 @@ public:
 
     constexpr static size_t SCRATCH_BYTE_SIZE = 32 * 1024;
 
-    class ShapeInfer : public Brgemm::ShapeInfer {
-    public:
-        explicit ShapeInfer(const std::shared_ptr& n);
-    };
-
-
 private:
     void custom_constructor_validate_and_infer_types(std::vector layout_a, std::vector layout_b, std::vector layout_c);
     void compute_block_size_values(const size_t blk_size_m, const size_t blk_size_k, const size_t blk_size_n);
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/set_brgemm_copy_b_buffers_shape.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/set_brgemm_copy_b_buffers_shape.cpp
new file mode 100644
index 00000000000..91bec8aee60
--- /dev/null
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/set_brgemm_copy_b_buffers_shape.cpp
@@ -0,0 +1,38 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#include "snippets/itt.hpp"
+
+#include "set_brgemm_copy_b_buffers_shape.hpp"
+#include "snippets/snippets_isa.hpp"
+#include "snippets/utils.hpp"
+
+#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
+
+bool ov::intel_cpu::pass::SetBrgemmCopyBBuffersShape::run(snippets::lowered::LinearIR& linear_ir) {
+    OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::SetBrgemmCopyBBuffersShape")
+
+    auto get_buffer_from_output = [](const snippets::lowered::ExpressionPtr& expr, const size_t out_idx) {
+        const auto& consumers = expr->get_output_port_connector(out_idx)->get_consumers();
+        OPENVINO_ASSERT(consumers.size() == 1, "BrgemmCopyB must have only 1 consumer");
+        const auto buffer = ov::as_type_ptr(consumers.begin()->get_expr()->get_node());
+        OPENVINO_ASSERT(buffer, "BrgemmCopyB consumer must be Buffer");
+        return buffer;
+    };
+
+    bool modified = false;
+    for (const auto& expr : linear_ir) {
+        if (auto copy_b = ov::as_type_ptr(expr->get_node())) {
+            const auto buffer = get_buffer_from_output(expr, 0);
+            const auto& out_desc = expr->get_output_port_descriptor(0);
+            buffer->set_allocation_shape(copy_b->get_data_repacking_shape(out_desc->get_shape()));
+            if (copy_b->is_with_compensations()) {
+                const auto compensations_buffer = get_buffer_from_output(expr, 1);
+                compensations_buffer->set_allocation_shape(copy_b->get_compensation_shape(out_desc->get_shape()));
+            }
+            modified = true;
+        }
+    }
+    return modified;
+}
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/set_brgemm_copy_b_buffers_shape.hpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/set_brgemm_copy_b_buffers_shape.hpp
new file mode 100644
index 00000000000..fcac51286e0
--- /dev/null
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/lowered/set_brgemm_copy_b_buffers_shape.hpp
@@ -0,0 +1,27 @@
+// Copyright (C) 2018-2023 Intel Corporation
+// SPDX-License-Identifier: Apache-2.0
+//
+
+#pragma once
+
+#include "snippets/lowered/pass/pass.hpp"
+
+namespace ov {
+namespace intel_cpu {
+namespace pass {
+
+/**
+ * @interface SetBrgemmCopyBBuffersShape
+ * @brief Sets the allocation shape for the Buffers after BrgemmCopyB node using BrgemmCopyB parameters
+ * @ingroup snippets
+ */
+class SetBrgemmCopyBBuffersShape: public snippets::lowered::pass::Pass {
+public:
+    SetBrgemmCopyBBuffersShape() = default;
+    OPENVINO_RTTI("SetBrgemmCopyBBuffersShape", "Pass");
+    bool run(snippets::lowered::LinearIR& linear_ir) override;
+};
+
+} // namespace pass
+} // namespace intel_cpu
+} // namespace ov
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp
index db6f34a4e74..df88ffa7edc 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/pass/set_brgemm_cpu_blocking_params.cpp
@@ -22,18 +22,6 @@
 
 namespace ov {
 namespace intel_cpu {
 
-using namespace snippets::lowered;
-namespace {
-template
-void change_desc_shape(const T& port) {
-    const auto desc = PortDescriptorUtils::get_port_descriptor_ptr(port);
-    const auto& shape = port.get_shape();
-    if (desc->get_shape() != shape) {
-        desc->set_shape(shape);
-    }
-}
-} // namespace
-
 pass::SetBrgemmCPUBlockingParams::SetBrgemmCPUBlockingParams() {
     MATCHER_SCOPE(SetBrgemmCPUBlockingParams);
@@ -73,7 +61,7 @@ pass::SetBrgemmCPUBlockingParams::SetBrgemmCPUBlockingParams() {
 
             const bool isAMXSupported = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx);
             const auto precision = brgemm_copy_b->get_src_element_type();
-            const auto brgemmVNNIFactor = 4 / precision.size();
+            const auto brgemmVNNIFactor = brgemm_copy_b->get_brgemm_vnni_factor();
             const bool use_amx = isAMXSupported && precision != ov::element::f32 && (K % brgemmVNNIFactor == 0) && (N % brgemmVNNIFactor == 0);
 
             const size_t copy_b_block_size_k = use_amx ? brgemm_block_size_k : K;
@@ -81,18 +69,8 @@ pass::SetBrgemmCPUBlockingParams::SetBrgemmCPUBlockingParams() {
             brgemm_copy_b->set_k_block_size(copy_b_block_size_k);
             brgemm_copy_b->set_n_block_size(copy_b_block_size_n);
 
-            // since N block size affects output shapes, the validation must be called explicitly right after the block size changing
-            brgemm_copy_b->validate_and_infer_types();
-            change_desc_shape(brgemm_copy_b->output(0));
-            if (brgemm_copy_b->is_with_compensations())
-                change_desc_shape(brgemm_copy_b->output(1));
         }
 
-        brgemm->validate_and_infer_types();
-        change_desc_shape(brgemm->input(1));
-        if (brgemm->is_with_scratchpad())
-            change_desc_shape(brgemm->input(2));
-
         return false;
     };
diff --git a/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp b/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp
index d09f3f218e6..6bb833262a5 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/x64/shape_inference.cpp
@@ -28,6 +28,8 @@ ShapeInferPtr CPUShapeInferSnippetsFactory::get_specific_op_shape_infer(const ov
     { OP::get_type_info_static(), [](const std::shared_ptr& n) { return std::make_shared();} }
 #define SHAPE_INFER_OP_SPECIFIC(OP) \
     { OP::get_type_info_static(), [](const std::shared_ptr& n) { return std::make_shared(n);} }
+#define SHAPE_INFER_OP_SPECIFIC_EXTERNAL(OP, InferType) \
+    { OP::get_type_info_static(), [](const std::shared_ptr& n) { return std::make_shared(n);} }
 
 const CPUShapeInferSnippetsFactory::TRegistry CPUShapeInferSnippetsFactory::specific_ops_registry {
     SHAPE_INFER_PREDEFINED(ov::intel_cpu::FusedMulAdd, NumpyBroadcastShapeInfer),
@@ -36,9 +38,9 @@ const CPUShapeInferSnippetsFactory::TRegistry CPUShapeInferSnippetsFactory::spec
     SHAPE_INFER_PREDEFINED(ov::intel_cpu::LoadConvertTruncation, PassThroughShapeInfer),
     SHAPE_INFER_PREDEFINED(ov::intel_cpu::StoreConvertSaturation, PassThroughShapeInfer),
     SHAPE_INFER_PREDEFINED(ov::intel_cpu::StoreConvertTruncation, PassThroughShapeInfer),
+    SHAPE_INFER_OP_SPECIFIC_EXTERNAL(ov::intel_cpu::BrgemmCPU, BrgemmShapeInfer),
     //
     SHAPE_INFER_OP_SPECIFIC(ov::intel_cpu::BrgemmCopyB),
-    SHAPE_INFER_OP_SPECIFIC(ov::intel_cpu::BrgemmCPU),
 };
 #undef SHAPE_INFER_OP_SPECIFIC
 #undef SHAPE_INFER_PREDEFINED