diff --git a/src/common/snippets/CMakeLists.txt b/src/common/snippets/CMakeLists.txt index d3a7e47c604..a37d5343092 100644 --- a/src/common/snippets/CMakeLists.txt +++ b/src/common/snippets/CMakeLists.txt @@ -26,7 +26,7 @@ ie_faster_build(${TARGET_NAME} ) target_link_libraries(${TARGET_NAME} PUBLIC openvino::runtime - PRIVATE ngraph_reference ov_shape_inference openvino::runtime::dev) + PRIVATE ngraph_reference openvino::runtime::dev) target_include_directories(${TARGET_NAME} PUBLIC $ PRIVATE $) diff --git a/src/common/snippets/include/snippets/generator.hpp b/src/common/snippets/include/snippets/generator.hpp index 939b4f4d43c..48715235c11 100644 --- a/src/common/snippets/include/snippets/generator.hpp +++ b/src/common/snippets/include/snippets/generator.hpp @@ -43,7 +43,6 @@ public: */ virtual size_t get_lanes() const = 0; - /** * @brief called by generator to all the emitter for a target machine * @return a map by node's type info with callbacks to create an instance of emitter for corresponding operation type @@ -155,7 +154,29 @@ public: */ std::shared_ptr get_target_machine() const; + /** + * @interface opRegType + * @brief Register type of operations + * Note that currently there are 4 types of ops: + * gpr->gpr: (Parameter, Result, LoopBegin, LoopEnd etc) + * gpr->vec: or vec->gpr Load/LoadConvert, Store/StoreConvert, BroadcastLoad etc. + * vec->vec: all other "normal" operations that perform calculations on vector registers: Add, BroadcastMove, Power, etc. 
+ */ + enum opRegType {gpr2gpr, gpr2vec, vec2gpr, vec2vec}; + /** + * @brief gets register type by op type + * TODO: Should be static attribute of emitters + * @return register type + */ + opRegType get_op_reg_type(const std::shared_ptr& op) const; + protected: + /** + * @brief gets register type by specific plugin op type + * @return register type + */ + virtual opRegType get_specific_op_reg_type(const std::shared_ptr& op) const; + std::shared_ptr target; // todo: we need to save lowered code to access compiled brgemm kernels on execution time (normally lowered is destructed by then). // This is temporary solution, remove this when kernel caching is implemented. Don't forget to make generate const method. diff --git a/src/common/snippets/include/snippets/op/brgemm.hpp b/src/common/snippets/include/snippets/op/brgemm.hpp index 2746d974a06..58c70f16479 100644 --- a/src/common/snippets/include/snippets/op/brgemm.hpp +++ b/src/common/snippets/include/snippets/op/brgemm.hpp @@ -5,7 +5,7 @@ #pragma once #include "ngraph/op/op.hpp" -#include "ngraph/op/matmul.hpp" +#include "memory_access.hpp" namespace ngraph { namespace snippets { @@ -16,30 +16,25 @@ namespace op { * @brief Brgemm is a batch-reduced matrix multiplication with the support of arbitrary strides between matrices rows * @ingroup snippets */ -class Brgemm : public ngraph::op::v0::MatMul { +class Brgemm : public MemoryAccess { public: - OPENVINO_OP("Brgemm", "SnippetsOpset", ngraph::op::v0::MatMul); - Brgemm(const Output& A, const Output& B, const size_t offset_a = 0lu, const size_t offset_b = 0lu, const size_t offset_c = 0lu); + OPENVINO_OP("Brgemm", "SnippetsOpset", MemoryAccess); + Brgemm(const Output& A, const Output& B, + const size_t offset_a = 0lu, const size_t offset_b = 0lu, const size_t offset_c = 0lu); Brgemm() = default; - bool visit_attributes(AttributeVisitor& visitor) override; + size_t get_offset_a() const { return get_input_offset(0); } + size_t get_offset_b() const { return 
get_input_offset(1); } + size_t get_offset_c() const { return get_output_offset(0); } + void validate_and_infer_types() override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; bool has_evaluate() const override { return false; } - size_t get_offset_a() const { return m_offset_a; } - size_t get_offset_b() const { return m_offset_b; } - size_t get_offset_c() const { return m_offset_c; } - - void set_offset_a(const size_t offset) { m_offset_a = offset; } - void set_offset_b(const size_t offset) { m_offset_b = offset; } - void set_offset_c(const size_t offset) { m_offset_c = offset; } - -private: - size_t m_offset_a = 0lu; // offset for first input - size_t m_offset_b = 0lu; // offset for second input - size_t m_offset_c = 0lu; // offset for output +protected: + ov::element::Type get_output_type() const; + ov::PartialShape get_output_partial_shape(const std::vector& input_shapes) const; }; } // namespace op diff --git a/src/common/snippets/include/snippets/op/broadcastload.hpp b/src/common/snippets/include/snippets/op/broadcastload.hpp index 43f3a329adc..edcbe170a37 100644 --- a/src/common/snippets/include/snippets/op/broadcastload.hpp +++ b/src/common/snippets/include/snippets/op/broadcastload.hpp @@ -4,7 +4,7 @@ #pragma once -#include +#include #include "ngraph/op/op.hpp" @@ -17,22 +17,21 @@ namespace op { * @brief Is generated for broadcasting by least varying dimension for non-blocked cases and the second varying dimension for blocked * @ingroup snippets */ -class BroadcastLoad : public BroadcastMove { +class BroadcastLoad : public MemoryAccess { public: - OPENVINO_OP("BroadcastLoad", "SnippetsOpset", ngraph::snippets::op::BroadcastMove); + OPENVINO_OP("BroadcastLoad", "SnippetsOpset", ngraph::snippets::op::MemoryAccess); BroadcastLoad(const Output& x, ov::PartialShape output_shape, size_t offset = 0lu); BroadcastLoad() = default; - size_t get_offset() const { return m_offset; } - void set_offset(const size_t offset) { m_offset 
= offset; } + size_t get_offset() const { return get_input_offset(0); } bool visit_attributes(AttributeVisitor& visitor) override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; void validate_and_infer_types() override; private: - size_t m_offset = 0lu; + ov::PartialShape output_shape; }; } // namespace op diff --git a/src/common/snippets/include/snippets/op/buffer.hpp b/src/common/snippets/include/snippets/op/buffer.hpp index f75fc95e742..8c6f98ac894 100644 --- a/src/common/snippets/include/snippets/op/buffer.hpp +++ b/src/common/snippets/include/snippets/op/buffer.hpp @@ -12,10 +12,9 @@ namespace op { /** * @interface Buffer - * @brief The operation is for intermediate data storage - * - m_allocation_rank - rank of shape for memory allocation: shape[shape_rank - normalize(m_allocation_rank) : shape_rank]. - * It's needed to allocate needed memory size that depends on Tile rank, for example. - * Default value is -1 (full shape) + * @brief This is a base class for memory storage. + * If Buffer has a parent, the operation is for intermediate data storage - IntermediateMemory type. + * Otherwise, the operation is for allocation of new empty memory with shape `m_shape` - NewMemory type * Notes: * - All buffers in a graph have the same memory pointer. 
So if we have a few buffers, * each the corresponding MemoryAccess op for Buffer should have offset for common memory pointer of this Buffer @@ -25,21 +24,30 @@ namespace op { class Buffer : public ngraph::op::Op { public: OPENVINO_OP("Buffer", "SnippetsOpset"); - - Buffer(const Output& x, const int32_t allocation_rank = -1); Buffer() = default; - - int32_t get_allocation_rank() const { return m_allocation_rank; } - void set_allocation_rank(int32_t rank) { m_allocation_rank = rank; } - - size_t get_byte_size() const; + Buffer(const ov::Shape& shape); + Buffer(const ov::Output& arg, const ov::Shape& shape); + Buffer(const ov::Output& arg, int32_t allocation_rank = -1); bool visit_attributes(AttributeVisitor& visitor) override; - std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; void validate_and_infer_types() override; + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; + + enum Type { + NewMemory, + IntermediateMemory + }; + + Type get_type() const { return m_type; } + ov::Shape get_allocation_shape() const { return m_shape; } + size_t get_byte_size() const; + + bool is_intermediate_memory() const { return m_type == Type::IntermediateMemory; } + bool is_new_memory() const { return m_type == Type::NewMemory; } private: - int32_t m_allocation_rank = -1; + Type m_type = Type::IntermediateMemory; + ov::Shape m_shape = {}; }; } // namespace op diff --git a/src/common/snippets/include/snippets/op/load.hpp b/src/common/snippets/include/snippets/op/load.hpp index bd0a4c5463f..38acd0e8a10 100644 --- a/src/common/snippets/include/snippets/op/load.hpp +++ b/src/common/snippets/include/snippets/op/load.hpp @@ -20,11 +20,18 @@ namespace op { */ class Load : public MemoryAccess { public: - OPENVINO_OP("Load", "SnippetsOpset"); + OPENVINO_OP("Load", "SnippetsOpset", MemoryAccess); Load(const Output& x, const size_t count = 1lu, const size_t offset = 0lu); Load() = default; + size_t get_offset() const { return 
get_input_offset(0); } + size_t get_count() const { return get_input_count(0); } + + void set_offset(size_t offset) { set_input_offset(offset, 0); } + void set_count(size_t count) { set_input_count(count, 0); } + + void validate_and_infer_types() override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; }; @@ -41,6 +48,9 @@ public: LoadReshape(const Output& x, size_t count = 1lu, const size_t offset = 0lu, std::vector order = {}); LoadReshape() = default; + void set_offset(size_t offset) { set_output_offset(offset, 0); } + void set_count(size_t count) { set_output_count(count, 0); } + bool visit_attributes(AttributeVisitor& visitor) override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; void validate_and_infer_types() override; diff --git a/src/common/snippets/include/snippets/op/memory_access.hpp b/src/common/snippets/include/snippets/op/memory_access.hpp index f1b2d8ebb2f..7b090c8f65d 100644 --- a/src/common/snippets/include/snippets/op/memory_access.hpp +++ b/src/common/snippets/include/snippets/op/memory_access.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2018-2022 Intel Corporation +// Copyright (C) 2018-2023 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -13,9 +13,9 @@ namespace op { /** * @interface MemoryAccess * @brief This is a base class for memory access operations (like Load and Store). - * It provides universal set/get interface to manipulate the number - * of elements accessed during one operation call ("count"). - * Default "count" value is "1" - it means to load/store one element + * It provides universal interface to manipulate with memory: load/store. 
+ * @param m_input_ports - vector of input descriptors: variables of PortDescriptor class + * @param m_output_ports - vector of output descriptors: variables of PortDescriptor class * @ingroup snippets */ @@ -23,18 +23,54 @@ class MemoryAccess : public ngraph::op::Op { public: OPENVINO_OP("MemoryAccess", "SnippetsOpset"); - size_t get_count() const; - size_t get_offset() const; - void set_count(const size_t count); - void set_offset(const size_t offset); + /** + * @interface PortDescriptor + * @brief This class describes port of MemoryAccess operation + * @param m_count - count of elements to load/store + * @param m_offset - starting index of elements to load/store + * @param m_index - port index + * @ingroup snippets + */ + struct PortDescriptor { + PortDescriptor(size_t count, size_t offset) : count(count), offset(offset) {} + PortDescriptor() = default; + + size_t count = 0lu; + size_t offset = 0lu; + size_t index = 0lu; + + private: + PortDescriptor(size_t count, size_t offset, size_t index) : count(count), offset(offset), index(index) {} + + friend class MemoryAccess; + }; + + void set_input_count(size_t count, size_t idx = 0); + void set_output_count(size_t count, size_t idx = 0); + void set_input_offset(size_t offset, size_t idx = 0); + void set_output_offset(size_t offset, size_t idx = 0); + + size_t get_input_count(size_t idx = 0) const; + size_t get_output_count(size_t idx = 0) const; + size_t get_input_offset(size_t idx = 0) const; + size_t get_output_offset(size_t idx = 0) const; + + size_t get_input_port_count() const { return m_input_ports.size(); } + size_t get_output_port_count() const { return m_output_ports.size(); } + bool visit_attributes(AttributeVisitor& visitor) override; - void validate_and_infer_types() override; protected: - explicit MemoryAccess(const Output& x, size_t count = 1lu, size_t offset = 0lu); + explicit MemoryAccess(const OutputVector& arguments, size_t input_count = 0, size_t output_count = 0); MemoryAccess() = default; - 
size_t m_count = 0lu; - size_t m_offset = 0lu; + + void set_input_port_descriptor(const PortDescriptor& desc, const size_t i); + void set_output_port_descriptor(const PortDescriptor& desc, const size_t i); + const PortDescriptor& get_input_port_descriptor(const size_t i) const; + const PortDescriptor& get_output_port_descriptor(const size_t i) const; + + std::vector m_input_ports; + std::vector m_output_ports; }; } // namespace op diff --git a/src/common/snippets/include/snippets/op/store.hpp b/src/common/snippets/include/snippets/op/store.hpp index 38715cffc6c..b62a4c6ccb1 100644 --- a/src/common/snippets/include/snippets/op/store.hpp +++ b/src/common/snippets/include/snippets/op/store.hpp @@ -20,11 +20,18 @@ namespace op { */ class Store : public MemoryAccess { public: - OPENVINO_OP("Store", "SnippetsOpset"); + OPENVINO_OP("Store", "SnippetsOpset", MemoryAccess); Store(const Output& x, const size_t count = 1lu, const size_t offset = 0lu); Store() = default; + size_t get_offset() const { return get_output_offset(0); } + size_t get_count() const { return get_output_count(0); } + + void set_offset(size_t offset) { set_output_offset(offset, 0); } + void set_count(size_t count) { set_output_count(count, 0); } + + void validate_and_infer_types() override; std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; }; diff --git a/src/common/snippets/include/snippets/pass/assign_registers.hpp b/src/common/snippets/include/snippets/pass/assign_registers.hpp index 144a8678784..81a5e3b2b29 100644 --- a/src/common/snippets/include/snippets/pass/assign_registers.hpp +++ b/src/common/snippets/include/snippets/pass/assign_registers.hpp @@ -6,6 +6,8 @@ #include +#include "snippets/generator.hpp" + namespace ngraph { namespace snippets { namespace pass { @@ -18,10 +20,13 @@ namespace pass { */ class AssignRegisters : public ngraph::pass::FunctionPass { public: - explicit AssignRegisters() { + explicit AssignRegisters(const std::function& op)>& mapper) : 
m_reg_type_mapper(mapper) { set_property(ngraph::pass::PassProperty::REQUIRE_STATIC_SHAPE, true); } bool run_on_model(const std::shared_ptr& m) override; + +private: + std::function& op)> m_reg_type_mapper; }; } // namespace pass diff --git a/src/common/snippets/include/snippets/utils.hpp b/src/common/snippets/include/snippets/utils.hpp index 253785b516d..3325ff42446 100644 --- a/src/common/snippets/include/snippets/utils.hpp +++ b/src/common/snippets/include/snippets/utils.hpp @@ -29,10 +29,31 @@ ov::PartialShape get_port_planar_shape(const Output& out); ov::PartialShape get_reordered_planar_shape(const ov::PartialShape& shape, const std::vector& layout); std::vector get_node_output_layout(const std::shared_ptr& node); std::vector get_node_output_layout(const Node* node); +void set_transpose_output_layout(const ov::Output& port, const std::shared_ptr& node); +void set_output_layout(const ov::Output& port, const std::vector& layout); inline ov::Dimension get_inner_dim(const ov::PartialShape &shape) { return *(shape.rbegin()); } inline ov::Dimension get_outer_dim(const ov::PartialShape &shape) { return *(shape.rbegin() + 1); } +inline auto normalize_rank(int32_t allocation_rank, const size_t shape_rank) -> int32_t { + return allocation_rank < 0 ? allocation_rank + static_cast(shape_rank) + 1 : allocation_rank; +} + +template +constexpr bool one_of(T val, P item) { return val == item; } + +template +constexpr bool one_of(T val, P item, Args... item_others) { + return val == item || one_of(val, item_others...); +} + +template +constexpr bool everyone_is(T val, P item) { return val == item; } + +template +constexpr bool everyone_is(T val, P item, Args... 
item_others) { + return val == item && everyone_is(val, item_others...); +} } // namespace utils } // namespace snippets } // namespace ngraph diff --git a/src/common/snippets/src/generator.cpp b/src/common/snippets/src/generator.cpp index 5ff9b9e19e2..dba0f139fda 100644 --- a/src/common/snippets/src/generator.cpp +++ b/src/common/snippets/src/generator.cpp @@ -77,8 +77,15 @@ auto tail_transformations(NodeVector& tail, const size_t tail_size, const ngraph } } } else if (const auto memory_access = std::dynamic_pointer_cast(op)) { - if (memory_access->get_count() != 1) { - memory_access->set_count(tail_size); + for (size_t i = 0; i < memory_access->get_input_port_count(); ++i) { + if (memory_access->get_input_count(i) > 1) { + memory_access->set_input_count(tail_size, i); + } + } + for (size_t i = 0; i < memory_access->get_output_port_count(); ++i) { + if (memory_access->get_output_count(i) > 1) { + memory_access->set_output_count(tail_size, i); + } } } updated_tile.push_back(op); @@ -220,5 +227,41 @@ std::shared_ptr Generator::get_target_machine() const { return target; } +Generator::opRegType Generator::get_op_reg_type(const std::shared_ptr& op) const { + if (std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op)) + return gpr2gpr; + else if (std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op)) + return gpr2vec; + else if (std::dynamic_pointer_cast(op)) + return vec2gpr; + else if (ov::op::util::is_unary_elementwise_arithmetic(op) || + ov::op::util::is_binary_elementwise_arithmetic(op) || + ov::op::util::is_binary_elementwise_comparison(op) || + ov::op::util::is_binary_elementwise_logical(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + 
std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op)) + return vec2vec; + else + return get_specific_op_reg_type(op); +} + +Generator::opRegType Generator::get_specific_op_reg_type(const std::shared_ptr& op) const { + throw ov::Exception("Register type of the operation " + std::string(op->get_type_name()) + " isn't determined!"); +} + + }// namespace snippets }// namespace ngraph diff --git a/src/common/snippets/src/op/brgemm.cpp b/src/common/snippets/src/op/brgemm.cpp index 7bf999cb15e..743653099b8 100644 --- a/src/common/snippets/src/op/brgemm.cpp +++ b/src/common/snippets/src/op/brgemm.cpp @@ -7,56 +7,123 @@ #include "ngraph/runtime/host_tensor.hpp" #include "openvino/core/rt_info.hpp" #include "snippets/utils.hpp" -#include "matmul_shape_inference.hpp" namespace ngraph { namespace snippets { namespace op { -Brgemm::Brgemm(const Output& A, const Output& B, const size_t offset_a, const size_t offset_b, const size_t offset_c) - : MatMul(), m_offset_a(offset_a), m_offset_b(offset_b), m_offset_c(offset_c) { - set_arguments({A, B}); +Brgemm::Brgemm(const Output& A, const Output& B, + const size_t offset_a, const size_t offset_b, const size_t offset_c) : MemoryAccess({A, B}, 2, 1) { set_output_size(1); + set_input_offset(offset_a, 0); + set_input_offset(offset_b, 1); + set_output_offset(offset_a, 0); constructor_validate_and_infer_types(); } -bool Brgemm::visit_attributes(AttributeVisitor& visitor) { - MatMul::visit_attributes(visitor); - visitor.on_attribute("offset_a", m_offset_a); - visitor.on_attribute("offset_b", m_offset_b); - visitor.on_attribute("offset_c", m_offset_c); - return true; -} - void Brgemm::validate_and_infer_types() { INTERNAL_OP_SCOPE(Brgemm_validate_and_infer_types); - element::Type result_et; - NODE_VALIDATION_CHECK(this, - element::Type::merge(result_et, get_input_element_type(0), get_input_element_type(1)), - "Arguments do not have the same element type (arg0 element type: ", - 
get_input_element_type(0), - ", arg1 element type: ", - get_input_element_type(1), - ")."); // If no leading dimensions are provided, assume dense row-major inputs-outputs NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static(), "Brgemm currently supports only static shapes."); - std::vector planar_input_shapes; - for (const auto& in : input_values()) - planar_input_shapes.emplace_back(utils::get_port_planar_shape(in)); + std::vector planar_input_shapes = { + utils::get_port_planar_shape(input_value(0)), + utils::get_port_planar_shape(input_value(1)) + }; - std::vector output_shapes = {ov::PartialShape{}}; - ov::op::v0::shape_infer(this, planar_input_shapes, output_shapes); + auto output_shape = get_output_partial_shape(planar_input_shapes); const auto& output_layout = utils::get_node_output_layout(this); - output_shapes[0] = utils::get_reordered_planar_shape(output_shapes[0], output_layout); - set_output_type(0, result_et, output_shapes[0]); + set_output_type(0, + get_output_type(), + utils::get_reordered_planar_shape(output_shape, output_layout)); } std::shared_ptr Brgemm::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(Brgemm_clone_with_new_inputs); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), new_args.at(1), m_offset_a, m_offset_b, m_offset_c); + return std::make_shared(new_args.at(0), new_args.at(1), get_offset_a(), get_offset_b(), get_offset_c()); +} + +ov::element::Type Brgemm::get_output_type() const { + const auto element_type_a = get_input_element_type(0); + const auto element_type_b = get_input_element_type(1); + const bool is_f32 = utils::everyone_is(element::f32, element_type_a, element_type_b); + const bool is_int8 = utils::one_of(element_type_a, element::i8, element::u8) && element_type_b == element::i8; + const bool is_bf16 = utils::everyone_is(element::bf16, element_type_a, element_type_b); + if (is_f32 || is_bf16) { + return 
element::f32; + } else if (is_int8) { + return element::i32; + } else { + throw ngraph_error("BrgemmCPU node has incompatible input element types: " + + element_type_a.get_type_name() + + " and " + + element_type_b.get_type_name()); + } +} + +ov::PartialShape Brgemm::get_output_partial_shape(const std::vector& input_shapes) const { + NGRAPH_CHECK(input_shapes.size() == 2, "BRGEMM expects 2 input shapes for shape inference"); + + // Note: All majors checks are missed because Brgemm is transformed from MatMul with whole shape infer support + + const auto arg0_shape = input_shapes[0]; + const auto arg1_shape = input_shapes[1]; + + size_t arg0_rank = arg0_shape.size(), arg1_rank = arg1_shape.size(); + + // temporary shapes to calculate output shape + ov::PartialShape arg0_shape_tmp(arg0_shape), arg1_shape_tmp(arg1_shape); + + // one-dimensional tensors unsqueezing is applied to each input independently. + if (arg0_rank == 1) { + // If the first input is 1D tensor, it is unsqueezed to 2D tensor (row vector) + // by adding axes with size 1 at ROW_INDEX_DIM, to the left of the shape. + // For example {S} will be reshaped to {1, S}. + arg0_shape_tmp.insert(arg0_shape_tmp.begin(), 1); + arg0_rank = arg0_shape_tmp.size(); + } + if (arg1_rank == 1) { + // If the second input is 1D tensor, it is unsqueezed to 2D tensor (column vector) + // by adding axes with size 1 at COL_INDEX_DIM, to the right of the shape. + // For example {S} will be reshaped to {S, 1}. 
+ arg1_shape_tmp.insert(arg1_shape_tmp.end(), 1); + arg1_rank = arg1_shape_tmp.size(); + } + // Check matrices dimensions compatibility, + using DimType = typename std::iterator_traits::value_type; + auto merged_dimension = DimType(); + auto arg0_col_dim = arg0_shape_tmp[arg0_rank - 1]; + auto arg1_row_dim = arg1_shape_tmp[arg1_rank - 2]; + OPENVINO_ASSERT(DimType::merge(merged_dimension, arg0_col_dim, arg1_row_dim) || arg0_col_dim.is_dynamic() || arg1_row_dim.is_dynamic(), + "Incompatible Brgemm matrix dimension"); + + // add 1 to begin to align shape ranks if needed + if (arg0_rank < arg1_rank) + arg0_shape_tmp.insert(arg0_shape_tmp.begin(), arg1_rank - arg0_rank, 1); + else if (arg0_rank > arg1_rank) + arg1_shape_tmp.insert(arg1_shape_tmp.begin(), arg0_rank - arg1_rank, 1); + + size_t max_rank = arg0_shape_tmp.size(); + std::vector output_shape(max_rank); + for (size_t i = 0; i < max_rank - 2; ++i) { + OPENVINO_ASSERT(DimType::broadcast_merge(output_shape[i], arg0_shape_tmp[i], arg1_shape_tmp[i]) || + arg0_shape_tmp[i].is_dynamic() || + arg1_shape_tmp[i].is_dynamic(), + "Incompatible Brgemm batch dimension"); + } + output_shape[output_shape.size() - 2] = arg0_shape_tmp[arg0_shape_tmp.size() - 2]; // M + output_shape[output_shape.size() - 1] = arg1_shape_tmp[arg1_shape_tmp.size() - 1]; // N + + // removing the temporary axes from originally 1D tensors. 
+ if (arg0_shape.rank().get_length() == 1) { + output_shape.erase(output_shape.begin() + output_shape.size() - 2); + } + if (arg1_shape.rank().get_length() == 1) { + output_shape.erase(output_shape.begin() + output_shape.size() - 1); + } + return output_shape; } } // namespace op diff --git a/src/common/snippets/src/op/broadcastload.cpp b/src/common/snippets/src/op/broadcastload.cpp index 0f4e6c7667e..ccbb5f9b9af 100644 --- a/src/common/snippets/src/op/broadcastload.cpp +++ b/src/common/snippets/src/op/broadcastload.cpp @@ -12,20 +12,20 @@ using namespace std; using namespace ngraph; snippets::op::BroadcastLoad::BroadcastLoad(const Output& x, ov::PartialShape shape, size_t offset) - : BroadcastMove(x, std::move(shape)), m_offset(offset) { + : MemoryAccess({x}, 1, 0), output_shape(std::move(shape)) { + set_input_port_descriptor({1, offset}, 0); constructor_validate_and_infer_types(); } bool snippets::op::BroadcastLoad::visit_attributes(AttributeVisitor& visitor) { - BroadcastMove::visit_attributes(visitor); - visitor.on_attribute("offset", m_offset); + MemoryAccess::visit_attributes(visitor); return true; } std::shared_ptr snippets::op::BroadcastLoad::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(BroadcastLoad); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), output_shape, m_offset); + return std::make_shared(new_args.at(0), output_shape, get_offset()); } void snippets::op::BroadcastLoad::validate_and_infer_types() { diff --git a/src/common/snippets/src/op/buffer.cpp b/src/common/snippets/src/op/buffer.cpp index ad05ae2e046..8a3963119b8 100644 --- a/src/common/snippets/src/op/buffer.cpp +++ b/src/common/snippets/src/op/buffer.cpp @@ -6,8 +6,8 @@ #include "snippets/op/buffer.hpp" #include "snippets/snippets_isa.hpp" +#include "snippets/utils.hpp" -#include using namespace std; using namespace ngraph; @@ -16,38 +16,64 @@ auto normalize_rank(int32_t allocation_rank, const size_t shape_rank) -> int32_t 
return allocation_rank < 0 ? allocation_rank + static_cast(shape_rank) : allocation_rank; } -snippets::op::Buffer::Buffer(const Output& x, const int32_t allocation_rank) : Op({x}), m_allocation_rank(allocation_rank) { +snippets::op::Buffer::Buffer(const ov::Shape& shape) + : Op(), m_type(Type::NewMemory), m_shape(shape) { + constructor_validate_and_infer_types(); +} + +snippets::op::Buffer::Buffer(const ov::Output& arg, const ov::Shape& shape) + : Op({arg}), m_type(Type::IntermediateMemory), m_shape(shape) { + constructor_validate_and_infer_types(); +} + +snippets::op::Buffer::Buffer(const ov::Output& arg, int32_t allocation_rank) + : Op({arg}), m_type(Type::IntermediateMemory) { + const auto pshape = arg.get_partial_shape(); + OPENVINO_ASSERT(pshape.is_static(), "Buffer supports only static input shape"); + const auto shape = pshape.get_shape(); + const auto normalize_rank = utils::normalize_rank(static_cast(allocation_rank), shape.size()); + const auto offset = static_cast(shape.size()) - normalize_rank; + m_shape = {shape.begin() + offset, shape.end()}; constructor_validate_and_infer_types(); } bool snippets::op::Buffer::visit_attributes(AttributeVisitor& visitor) { INTERNAL_OP_SCOPE(Buffer_visit_attributes); - visitor.on_attribute("allocation_rank", m_allocation_rank); + visitor.on_attribute("allocation_shape", m_shape); return true; } +void snippets::op::Buffer::validate_and_infer_types() { + INTERNAL_OP_SCOPE(Buffer_validate_and_infer_types); + ov::element::Type output_type; + ov::Shape output_shape; + if (m_type == Type::NewMemory) { + OPENVINO_ASSERT(get_input_size() == 0, "Buffer with new allocated memory must to not have arguments!"); + output_shape = m_shape; + output_type = ov::element::u8; // 1Byte + } else if (m_type == Type::IntermediateMemory) { + const auto input_shape = get_input_partial_shape(0); + OPENVINO_ASSERT(input_shape.is_static(), "Buffer supports only static input shape"); + output_type = get_input_element_type(0); + output_shape = 
input_shape.get_shape(); + } else { + throw ov::Exception("Buffer supports only the following types: NewMemory and IntermediateMemory"); + } + set_output_type(0, output_type, output_shape); +} + std::shared_ptr snippets::op::Buffer::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(Buffer_clone_with_new_inputs); check_new_args_count(this, new_args); - auto new_buffer = std::make_shared(new_args.at(0), m_allocation_rank); - return new_buffer; -} - -void snippets::op::Buffer::validate_and_infer_types() { - INTERNAL_OP_SCOPE(Buffer_validate_and_infer_types); - const auto shape_rank = get_input_partial_shape(0).rank(); - if (shape_rank.is_static()) { - const auto normalized_rank = normalize_rank(m_allocation_rank, shape_rank.get_length()); - NGRAPH_CHECK(normalized_rank >= 0 && normalized_rank <= shape_rank.get_length(), - "Buffer has incorrect allocation rank: " + std::to_string(m_allocation_rank)); + if (m_type == Type::NewMemory) { + return std::make_shared(m_shape); + } else if (m_type == Type::IntermediateMemory) { + return std::make_shared(new_args.at(0), m_shape); } - set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); + throw ov::Exception("Buffer supports only the following types: NewMemory and IntermediateMemory"); } size_t ngraph::snippets::op::Buffer::get_byte_size() const { - const auto pshape = get_input_partial_shape(0); - NGRAPH_CHECK(pshape.is_static(), "Buffer should have static shapes for memory allocation"); - const auto shape = pshape.get_shape(); - const auto normalized_rank = normalize_rank(m_allocation_rank, shape.size()); - return ngraph::shape_size(shape.rbegin(), shape.rbegin() + normalized_rank) * get_element_type().size(); + const auto shape = get_allocation_shape(); + return ngraph::shape_size(shape) * get_element_type().size(); } diff --git a/src/common/snippets/src/op/load.cpp b/src/common/snippets/src/op/load.cpp index 8ee227c7afb..f1f5bc42c7a 100644 --- 
a/src/common/snippets/src/op/load.cpp +++ b/src/common/snippets/src/op/load.cpp @@ -12,17 +12,24 @@ namespace ngraph { namespace snippets { namespace op { -Load::Load(const Output& x, const size_t count, const size_t offset) : MemoryAccess({x}, count, offset) { +Load::Load(const Output& x, const size_t count, const size_t offset) : MemoryAccess({x}, 1, 0) { + set_input_port_descriptor({count, offset}, 0); constructor_validate_and_infer_types(); } +void snippets::op::Load::validate_and_infer_types() { + // Load has memory access port only on output + OPENVINO_ASSERT(get_input_port_count() == 1, "Load node must have memory access input port"); + OPENVINO_ASSERT(get_output_port_count() == 0, "Load node mustn't have memory access output port"); + set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); +} + std::shared_ptr Load::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(Load); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), m_count, m_offset); + return std::make_shared(new_args.at(0), get_count(), get_offset()); } - LoadReshape::LoadReshape(const Output& x, const size_t count, const size_t offset, std::vector order) : Load(x, count, offset), m_order(std::move(order)) { const auto& in_shape = x.get_partial_shape(); @@ -33,6 +40,8 @@ LoadReshape::LoadReshape(const Output& x, const size_t count, const si *std::min_element(m_order.begin(), m_order.end()) == 0, "LoadReshape detected invalid values in new_order"); const std::set unique_dims(order.begin(), order.end()); NGRAPH_CHECK(unique_dims.size() == order.size(), "LoadReshape order must not contain repeated elements"); + m_input_ports.resize(get_input_size()); + set_input_port_descriptor({count, offset}, 0); constructor_validate_and_infer_types(); } @@ -53,7 +62,7 @@ bool snippets::op::LoadReshape::visit_attributes(AttributeVisitor& visitor) { std::shared_ptr snippets::op::LoadReshape::clone_with_new_inputs(const OutputVector& new_args) 
const { INTERNAL_OP_SCOPE(LoadReshape); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), m_count, m_offset, m_order); + return std::make_shared(new_args.at(0), get_count(), get_offset(), m_order); } }// namespace op diff --git a/src/common/snippets/src/op/memory_access.cpp b/src/common/snippets/src/op/memory_access.cpp index 2530ea77b63..ea0e4649f9e 100644 --- a/src/common/snippets/src/op/memory_access.cpp +++ b/src/common/snippets/src/op/memory_access.cpp @@ -3,43 +3,80 @@ // #include - #include "snippets/op/memory_access.hpp" -#include - namespace ngraph { namespace snippets { namespace op { -MemoryAccess::MemoryAccess(const Output& x, const size_t count, const size_t offset) : Op({x}), m_count(count), m_offset(offset) {} +MemoryAccess::MemoryAccess(const OutputVector& arguments, size_t input_count, size_t output_count) : Op(arguments) { + while (m_input_ports.size() < input_count) { + m_input_ports.push_back({0, 0, m_input_ports.size()}); + } + while (m_output_ports.size() < output_count) { + m_output_ports.push_back({0, 0, m_output_ports.size()}); + } +} bool MemoryAccess::visit_attributes(AttributeVisitor& visitor) { - visitor.on_attribute("count", m_count); - visitor.on_attribute("offset", m_offset); + for (size_t i = 0; i < m_input_ports.size(); ++i) { + auto port = m_input_ports[i]; + visitor.on_attribute("count_in_" + std::to_string(i), port.count); + visitor.on_attribute("offset_in_" + std::to_string(i), port.offset); + } + for (size_t i = 0; i < m_output_ports.size(); ++i) { + auto port = m_output_ports[i]; + visitor.on_attribute("count_out_" + std::to_string(i), port.count); + visitor.on_attribute("offset_out_" + std::to_string(i), port.offset); + } return true; } -size_t MemoryAccess::get_count() const { - return m_count; +void MemoryAccess::set_input_port_descriptor(const PortDescriptor& desc, const size_t i) { + NGRAPH_CHECK(i < m_input_ports.size(), "Index of input port descriptor should be less than count of input 
ports"); + m_input_ports[i] = { desc.count, desc.offset, i}; } -size_t MemoryAccess::get_offset() const { - return m_offset; +void MemoryAccess::set_output_port_descriptor(const PortDescriptor& desc, const size_t i) { + NGRAPH_CHECK(i < m_output_ports.size(), "Index of output port descriptor should be less than count of output ports"); + m_output_ports[i] = { desc.count, desc.offset, i}; } -void MemoryAccess::set_count(const size_t count) { - m_count = count; +const MemoryAccess::PortDescriptor& MemoryAccess::get_input_port_descriptor(const size_t i) const { + NGRAPH_CHECK(i < m_input_ports.size(), "Index of input port descriptor should be less than count of input ports"); + return m_input_ports[i]; } -void MemoryAccess::set_offset(const size_t offset) { - m_offset = offset; +const MemoryAccess::PortDescriptor& MemoryAccess::get_output_port_descriptor(const size_t i) const { + NGRAPH_CHECK(i < m_output_ports.size(), "Index of output port descriptor should be less than count of output ports"); + return m_output_ports[i]; } -void MemoryAccess::validate_and_infer_types() { - set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); +void MemoryAccess::set_input_count(size_t count, size_t idx) { + set_input_port_descriptor({count, get_input_port_descriptor(idx).offset, idx}, idx); +} +void MemoryAccess::set_output_count(size_t count, size_t idx) { + set_output_port_descriptor({count, get_output_port_descriptor(idx).offset, idx}, idx); +} +void MemoryAccess::set_input_offset(size_t offset, size_t idx) { + set_input_port_descriptor({get_input_port_descriptor(idx).count, offset, idx}, idx); +} +void MemoryAccess::set_output_offset(size_t offset, size_t idx) { + set_output_port_descriptor({get_output_port_descriptor(idx).count, offset, idx}, idx); +} +size_t MemoryAccess::get_input_count(size_t idx) const { + return get_input_port_descriptor(idx).count; +} +size_t MemoryAccess::get_output_count(size_t idx) const { + return 
get_output_port_descriptor(idx).count; +} +size_t MemoryAccess::get_input_offset(size_t idx) const { + return get_input_port_descriptor(idx).offset; +} +size_t MemoryAccess::get_output_offset(size_t idx) const { + return get_output_port_descriptor(idx).offset; } } // namespace op } // namespace snippets -} // namespace ngraph \ No newline at end of file +} // namespace ngraph diff --git a/src/common/snippets/src/op/store.cpp b/src/common/snippets/src/op/store.cpp index 2cee1b20751..8ac2c4cdf17 100644 --- a/src/common/snippets/src/op/store.cpp +++ b/src/common/snippets/src/op/store.cpp @@ -12,13 +12,22 @@ namespace ngraph { namespace snippets { namespace op { -snippets::op::Store::Store(const Output& x, const size_t count, const size_t offset) : MemoryAccess({x}, count, offset) { +snippets::op::Store::Store(const Output& x, const size_t count, const size_t offset) : MemoryAccess({x}, 0, 1) { + set_output_port_descriptor({count, offset}, 0); constructor_validate_and_infer_types(); } + +void snippets::op::Store::validate_and_infer_types() { + // Store has memory access port only on output + OPENVINO_ASSERT(get_input_port_count() == 0, "Store node mustn't have memory access input port"); + OPENVINO_ASSERT(get_output_port_count() == 1, "Store node must have memory access output port"); + set_output_type(0, get_input_element_type(0), get_input_partial_shape(0)); +} + std::shared_ptr snippets::op::Store::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(Store_clone_with_new_inputs); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), m_count, m_offset); + return std::make_shared(new_args.at(0), get_count(), get_offset()); } } // namespace op diff --git a/src/common/snippets/src/op/subgraph.cpp b/src/common/snippets/src/op/subgraph.cpp index 20b6edb17b9..f8953745520 100644 --- a/src/common/snippets/src/op/subgraph.cpp +++ b/src/common/snippets/src/op/subgraph.cpp @@ -434,22 +434,21 @@ void 
snippets::op::Subgraph::initialize_buffer_scratchpad_size() { // Propagate to up: in Store. Buffer can have only one Store { - auto parent = buffer->get_input_node_shared_ptr(0); - auto idx = buffer->input(0).get_source_output().get_index(); - // There may be graph with several LoopBegin and LoopEnd between Store/Brgemm and Buffer, - // so we should iterate through LoopBase - while (ov::is_type(parent)) { - const auto source_output = parent->input_value(idx); - parent = source_output.get_node_shared_ptr(); - idx = source_output.get_index(); - } - if (auto store = ov::as_type_ptr(parent)) { - store->set_offset(offset); - } else if (const auto brgemm = ov::as_type_ptr(parent)) { - // Brgemm encapsulates work with loading and storing of data - brgemm->set_offset_c(offset); - } else { - throw ngraph_error("Buffer::set_offset() was called when Buffer didn't have the corresponding Store op for offset propagation"); + if (buffer->is_intermediate_memory()) { + OPENVINO_ASSERT(buffer->get_input_size() == 1, "Buffer with intermediate memory must have one parent"); + auto parent = buffer->get_input_node_shared_ptr(0); + auto idx = buffer->input(0).get_source_output().get_index(); + while (ov::is_type(parent)) { + const auto source_output = parent->input_value(idx); + parent = source_output.get_node_shared_ptr(); + idx = source_output.get_index(); + } + if (auto memory_access = ov::as_type_ptr(parent)) { + memory_access->set_output_offset(offset, idx); + } else { + throw ngraph_error( + "Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation"); + } } } @@ -466,17 +465,10 @@ void snippets::op::Subgraph::initialize_buffer_scratchpad_size() { for (const auto loop_target_output : child->output(index).get_target_inputs()) { propagate_down(loop_target_output); } - } else if (const auto load = ov::as_type_ptr(child)) { - load->set_offset(offset); - } else if (const auto brgemm = ov::as_type_ptr(child)) { - // Brgemm 
encapsulates work with loading and storing of data - if (target_input.get_index() == 0) { - brgemm->set_offset_a(offset); - } else if (target_input.get_index() == 1) { - brgemm->set_offset_b(offset); - } + } else if (auto memory_access = ov::as_type_ptr(child)) { + memory_access->set_input_offset(offset, target_input.get_index()); } else { - throw ngraph_error("Buffer::set_offset() was called when Buffer didn't have the corresponding Load op for offset propagation"); + throw ngraph_error("Buffer::set_offset() was called when Buffer didn't have the corresponding MemoryAccess op for offset propagation"); } }; @@ -497,26 +489,25 @@ void snippets::op::Subgraph::initialize_buffer_scratchpad_size() { continue; } - // Transpose and MatMul ops should have different memories on inputs and outputs to avoid data corruption, - // so after them, we should allocate new memory. Other operations (Eltwises, Convert) can be executed inplace. - const auto parent = buffer->get_input_node_shared_ptr(0); - if (ov::is_type(parent) || is_transpose_loop(parent)) { + if (buffer->is_intermediate_memory()) { + // Transpose, MatMul and other non-decomposed ops should have different memories on inputs and outputs to avoid data corruption, + // so after them, we should allocate new memory. Other operations (Eltwises, Convert) can be executed inplace inside Loop. 
+ OPENVINO_ASSERT(buffer->get_input_size() == 1, "Buffer with intermediate memory must have one parent"); + const auto parent = buffer->get_input_node_shared_ptr(0); + if (!ov::is_type(parent) || is_transpose_loop(parent)) { + offset = m_buffer_scratchpad; + propagate_offset(buffer, offset); + m_buffer_scratchpad += buffer_size; + continue; + } + + propagate_offset(buffer, offset); + } else { + // Single Buffer without input should allocate new memory offset = m_buffer_scratchpad; propagate_offset(buffer, offset); m_buffer_scratchpad += buffer_size; - continue; } - - // If Buffer op requires memory size more that has been already allocated, - // we increase current memory size to the needed size - // For example, it's possible when we have a sequence of Eltwise ops with broadcasting - const auto current_allocated_memory_size = m_buffer_scratchpad - offset; - if (buffer_size > current_allocated_memory_size) { - m_buffer_scratchpad += (buffer_size - current_allocated_memory_size); - // Note: we don't update offset because we just add memory to needed size - } - - propagate_offset(buffer, offset); } } } @@ -644,7 +635,10 @@ snippets::Schedule snippets::op::Subgraph::generate( if (config.m_has_domain_sensitive_ops) initialize_buffer_scratchpad_size(); - snippets::pass::AssignRegisters().run_on_model(body_ptr()); + std::function& op)> reg_type_mapper = [=](const std::shared_ptr& op) -> Generator::opRegType { + return m_generator->get_op_reg_type(op); + }; + snippets::pass::AssignRegisters(reg_type_mapper).run_on_model(body_ptr()); const auto ops = body_ptr()->get_ops(); ngraph::snippets::Generator::GeneratorConfig generatorConfig; diff --git a/src/common/snippets/src/pass/assign_registers.cpp b/src/common/snippets/src/pass/assign_registers.cpp index 3de3138db60..c9af20443b8 100644 --- a/src/common/snippets/src/pass/assign_registers.cpp +++ b/src/common/snippets/src/pass/assign_registers.cpp @@ -14,6 +14,7 @@ namespace { constexpr size_t reg_count = 16lu; +using 
opRegType = ngraph::snippets::Generator::opRegType; } // namespace bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr& f) { @@ -22,31 +23,12 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr using Reg = size_t; using tensor = std::shared_ptr; auto ops = f->get_ordered_ops(); - // Note that currently there are 3 types of ops: - // * gpr->gpr: (Parameter, Result, LoopBegin, LoopEnd) will also be Buffer? - // * gpr->vec: or vec->gpr Load/LoadConvert, Store/StoreConvert, BroadcastLoad etc. - // * vec->vec: all other "normal" operations that perform calculations on vector registers: Add, BroadcastMove, Power, etc. - enum op_reg_type {gpr2gpr, gpr2vec, vec2gpr, vec2vec}; - auto get_op_reg_type = [](const std::shared_ptr& op) { - if (std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op)) - return gpr2gpr; - else if (std::dynamic_pointer_cast(op) || - std::dynamic_pointer_cast(op)) - return gpr2vec; - else if (std::dynamic_pointer_cast(op)) - return vec2gpr; - else - return vec2vec; - }; - std::vector>> typed_ops; - for (const auto& op : ops) - typed_ops.emplace_back(std::make_pair(get_op_reg_type(op), op)); + std::vector>> typed_ops; + for (const auto& op : ops) { + typed_ops.emplace_back(std::make_pair(m_reg_type_mapper(op), op)); + } + size_t counter_vec = 0; size_t counter_gpr = 0; std::map regs_vec, regs_gpr; @@ -64,10 +46,12 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr // here we use the fact that Result input & output tensors are identical by construction manually_assigned_gprs[op->output(0).get_tensor_ptr()] = static_cast(f->get_result_index(result) + num_parameters); - } else if (const auto& buffer = ov::as_type_ptr(op)) { + } else if (const auto buffer = ov::as_type_ptr(op)) { // All buffers have one common data 
pointer - manually_assigned_gprs[op->input(0).get_tensor_ptr()] = - static_cast(num_results + num_parameters); + if (buffer->is_intermediate_memory()) { + manually_assigned_gprs[op->input(0).get_tensor_ptr()] = + static_cast(num_results + num_parameters); + } manually_assigned_gprs[op->output(0).get_tensor_ptr()] = static_cast(num_results + num_parameters); } else if (ov::is_type(op) || ov::is_type(op)) { @@ -114,12 +98,12 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr }; for (const auto& t_op : typed_ops) { switch (t_op.first) { - case vec2vec: - case gpr2vec: + case opRegType::vec2vec: + case opRegType::gpr2vec: enumerate_out_tensors(t_op.second, regs_vec, manually_assigned_vecs, counter_vec); break; - case gpr2gpr: - case vec2gpr: + case opRegType::gpr2gpr: + case opRegType::vec2gpr: enumerate_out_tensors(t_op.second, regs_gpr, manually_assigned_gprs, counter_gpr); break; } @@ -144,24 +128,25 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr for (size_t i = 0; i < typed_ops.size(); i++) { const auto& t_op = typed_ops[i]; std::vector used_tensors, defined_tensors; - for (const auto& in : t_op.second->inputs()) + for (const auto& in : t_op.second->inputs()) { used_tensors.push_back(in.get_tensor_ptr()); + } for (const auto& out : t_op.second->outputs()) defined_tensors.push_back(out.get_tensor_ptr()); switch (t_op.first) { - case vec2vec: + case opRegType::vec2vec: used_vec[i] = tensor2reg(used_tensors, regs_vec); defined_vec[i] = tensor2reg(defined_tensors, regs_vec); break; - case gpr2gpr: + case opRegType::gpr2gpr: used_gpr[i] = tensor2reg(used_tensors, regs_gpr); defined_gpr[i] = tensor2reg(defined_tensors, regs_gpr); break; - case gpr2vec: + case opRegType::gpr2vec: used_gpr[i] = tensor2reg(used_tensors, regs_gpr); defined_vec[i] = tensor2reg(defined_tensors, regs_vec); break; - case vec2gpr: + case opRegType::vec2gpr: used_vec[i] = tensor2reg(used_tensors, regs_vec); defined_gpr[i] = 
tensor2reg(defined_tensors, regs_gpr); break; @@ -196,12 +181,12 @@ bool ngraph::snippets::pass::AssignRegisters::run_on_model(const std::shared_ptr if (k == ops.size()) throw ngraph_error("assign registers can't find target op in the body"); switch (typed_ops[k].first) { - case vec2vec: - case vec2gpr: + case opRegType::vec2vec: + case opRegType::vec2gpr: life_out_vec[n].insert(life_in_vec[k].begin(), life_in_vec[k].end()); break; - case gpr2gpr: - case gpr2vec: + case opRegType::gpr2gpr: + case opRegType::gpr2vec: life_out_gpr[n].insert(life_in_gpr[k].begin(), life_in_gpr[k].end()); break; } diff --git a/src/common/snippets/src/pass/collapse_subgraph.cpp b/src/common/snippets/src/pass/collapse_subgraph.cpp index 3325881834f..af962adaa64 100644 --- a/src/common/snippets/src/pass/collapse_subgraph.cpp +++ b/src/common/snippets/src/pass/collapse_subgraph.cpp @@ -49,9 +49,16 @@ auto outputs_are_not_broadcastable(const std::shared_ptr& node) -> b auto is_supported_op(const std::shared_ptr &n) -> bool { OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "Snippets::is_supported_op") auto is_supported_matmul = [](const std::shared_ptr& n) -> bool { - const auto& matmul = is_type(n); + const auto& matmul = ov::as_type_ptr(n); const auto& out_shape = n->get_output_partial_shape(0); - return matmul && out_shape.is_static() && out_shape.size() == 4; + if (!matmul || out_shape.is_dynamic() || out_shape.size() != 4) + return false; + const auto intype_0 = matmul->get_input_element_type(0); + const auto intype_1 = matmul->get_input_element_type(1); + const bool is_f32 = intype_0 == element::f32 && intype_1 == element::f32; + const bool is_int8 = (intype_0 == element::i8 || intype_0 == element::u8) && (intype_1 == element::i8); + const bool is_bf16 = intype_0 == element::bf16 && intype_1 == element::bf16; + return is_f32 || is_bf16 || is_int8; }; auto is_supported_transpose = [](const std::shared_ptr& n) -> bool { const auto& transpose = as_type_ptr(n); diff --git 
a/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp b/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp index f50731bcf7c..62dd1292b3f 100644 --- a/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp +++ b/src/common/snippets/src/pass/fuse_transpose_brgemm.cpp @@ -49,13 +49,8 @@ FuseTransposeBrgemm::FuseTransposeBrgemm() { auto callback = [=](pattern::Matcher& m) { OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "FuseTransposeBrgemm") - auto set_layout_from_order = [](const std::shared_ptr& node, const ov::Output& port) { - const auto& const_order = as_type_ptr(node->get_input_node_shared_ptr(1)); - std::vector layout = const_order->cast_vector(); - auto& rt_info = port.get_node_shared_ptr()->get_rt_info(); - rt_info["Layout"] = layout; - }; auto brgemm = as_type_ptr(m.get_match_root()); + // Transpose on the Brgemm's output if (!brgemm) { brgemm = as_type_ptr(m.get_match_root()->get_input_node_shared_ptr(0)); @@ -63,13 +58,13 @@ FuseTransposeBrgemm::FuseTransposeBrgemm() { const auto& transpose_out = m.get_match_value(); for (const auto& in : transpose_out.get_target_inputs()) in.replace_source_output(brgemm->output(0)); - set_layout_from_order(as_type_ptr(transpose_out.get_node_shared_ptr()), brgemm_out); + utils::set_transpose_output_layout(brgemm_out, as_type_ptr(transpose_out.get_node_shared_ptr())); } for (size_t i = 0; i < brgemm->get_input_size(); i++) { const auto& in_value = brgemm->input_value(i); if (transpose_matcher->match(in_value)) { const auto& transpose = as_type_ptr(in_value.get_node_shared_ptr()); - set_layout_from_order(transpose, transpose->input_value(0)); + utils::set_transpose_output_layout(transpose->input_value(0), transpose); brgemm->set_argument(i, transpose->input_value(0)); } } diff --git a/src/common/snippets/src/pass/insert_buffer.cpp b/src/common/snippets/src/pass/insert_buffer.cpp index 1b080bc0b0c..e7f4c90ae02 100644 --- a/src/common/snippets/src/pass/insert_buffer.cpp +++ 
b/src/common/snippets/src/pass/insert_buffer.cpp @@ -31,7 +31,7 @@ ngraph::snippets::pass::InsertBuffer::InsertBuffer(const int32_t allocation_rank if (!ov::is_type(input_node) && !ov::is_type(input_node) && !ov::is_type(input_node)) { - const auto buffer = std::make_shared(input_node, allocation_rank); + const auto buffer = std::make_shared(input_node, allocation_rank); root->set_argument(input.get_index(), buffer); rewritten |= true; } @@ -68,7 +68,7 @@ ngraph::snippets::pass::InsertBuffer::InsertBuffer(const int32_t allocation_rank } } - const auto buffer = std::make_shared(output, allocation_rank); + const auto buffer = std::make_shared(output, allocation_rank); for (const auto& consumer : output.get_target_inputs()) { const auto output_node = consumer.get_node()->shared_from_this(); if (output_node != buffer && diff --git a/src/common/snippets/src/pass/insert_load_store.cpp b/src/common/snippets/src/pass/insert_load_store.cpp index ef0fed11b50..114393bd872 100644 --- a/src/common/snippets/src/pass/insert_load_store.cpp +++ b/src/common/snippets/src/pass/insert_load_store.cpp @@ -30,7 +30,7 @@ ngraph::snippets::pass::InsertLoad::InsertLoad(const size_t count) { const auto& consumer_node = consumer.get_node(); if (ov::is_type(consumer_node) || ov::is_type(consumer_node) || - ov::is_type(consumer_node) || + ov::is_type(consumer_node) || ov::is_type(consumer_node)) { return false; } @@ -67,7 +67,7 @@ ngraph::snippets::pass::InsertStore::InsertStore(const size_t count) { const auto& parent_node = input.get_source_output().get_node(); if (ov::is_type(parent_node) || ov::is_type(parent_node) || - ov::is_type(parent_node) || + ov::is_type(parent_node) || ov::is_type(parent_node)) { return false; } diff --git a/src/common/snippets/src/pass/load_movebroadcast_to_broadcastload.cpp b/src/common/snippets/src/pass/load_movebroadcast_to_broadcastload.cpp index b4fdb2506dc..7aa69d65bbd 100644 --- a/src/common/snippets/src/pass/load_movebroadcast_to_broadcastload.cpp +++ 
b/src/common/snippets/src/pass/load_movebroadcast_to_broadcastload.cpp @@ -24,20 +24,20 @@ ngraph::snippets::pass::LoadMoveBroadcastToBroadcastLoad::LoadMoveBroadcastToBro auto root = m.get_match_root(); const auto &pm = m.get_pattern_value_map(); - const auto input = pm.at(load_pattern).get_node_shared_ptr(); + const auto load = ov::as_type_ptr(pm.at(load_pattern).get_node_shared_ptr()); const auto param = pm.at(param_pattern).get_node_shared_ptr(); // Cannot rewrite Broadcast + Load if load has more than 1 user // or more than one input, or if Broadcast has several inputs - if (input->output(0).get_target_inputs().size() != 1 || - root->inputs().size() != 1 || input->inputs().size() != 1) { + if (load->output(0).get_target_inputs().size() != 1 || + root->inputs().size() != 1 || load->inputs().size() != 1) { return false; } auto inshape = root->input(0).get_partial_shape(); auto outshape = root->output(0).get_partial_shape(); - auto broadcastload = std::make_shared(param, outshape, ov::as_type_ptr(input)->get_offset()); + auto broadcastload = std::make_shared(param, outshape, load->get_offset()); ngraph::copy_runtime_info(root, broadcastload); ngraph::replace_node(root, broadcastload); diff --git a/src/common/snippets/src/pass/loop_fusion.cpp b/src/common/snippets/src/pass/loop_fusion.cpp index 18287d4464f..2291e074607 100644 --- a/src/common/snippets/src/pass/loop_fusion.cpp +++ b/src/common/snippets/src/pass/loop_fusion.cpp @@ -73,7 +73,6 @@ auto get_buffer_and_loop_end(const std::shared_ptr(parent_shared); if (buffer) { if (buffer->output(0).get_target_inputs().size() == 0 || - buffer->get_input_size() != 1 || buffer->get_input_source_output(0).get_target_inputs().size() != 1) return false; diff --git a/src/common/snippets/src/pass/matmul_to_brgemm.cpp b/src/common/snippets/src/pass/matmul_to_brgemm.cpp index b74fb3e68cc..add672b0fef 100644 --- a/src/common/snippets/src/pass/matmul_to_brgemm.cpp +++ b/src/common/snippets/src/pass/matmul_to_brgemm.cpp @@ -6,7 
+6,7 @@ #include "snippets/pass/matmul_to_brgemm.hpp" -#include "snippets/op/brgemm.hpp" +#include "snippets/snippets_isa.hpp" #include "ngraph/opsets/opset1.hpp" #include "ngraph/rt_info.hpp" @@ -30,9 +30,13 @@ MatMulToBrgemm::MatMulToBrgemm() { return false; auto brgemm = std::make_shared(matmul->get_input_source_output(0), matmul->get_input_source_output(1)); + ov::NodeVector nodes = { brgemm }; + if (brgemm->get_output_element_type(0) != matmul->get_output_element_type(0)) { + nodes.emplace_back(std::make_shared(brgemm, matmul->get_output_element_type(0))); + } brgemm->set_friendly_name(matmul->get_friendly_name()); - ngraph::copy_runtime_info(matmul, brgemm); - ngraph::replace_node(matmul, brgemm); + ngraph::copy_runtime_info(matmul, nodes); + ngraph::replace_node(matmul, nodes.back()); return true; }; diff --git a/src/common/snippets/src/pass/reset_buffer.cpp b/src/common/snippets/src/pass/reset_buffer.cpp index bae2ac58ccd..54bdfef03f7 100644 --- a/src/common/snippets/src/pass/reset_buffer.cpp +++ b/src/common/snippets/src/pass/reset_buffer.cpp @@ -79,10 +79,9 @@ ngraph::snippets::pass::ResetBufferState::ResetBufferState() { // If after Loop there is immediately Buffer, we should reset the Buffer ptr for the next calculations for (size_t i = 0; i < o_size; ++i) { - const auto result_shape = body_shapes[i_size + i].get_shape(); // check for first target input is enough for Buffer searching because operations can have only single Buffer per each output port as op const auto consumer = loop_end->output(i).get_target_inputs().begin()->get_node(); - if (ov::is_type(consumer)) { + if (const auto buffer = ov::as_type_ptr(consumer->shared_from_this())) { // To calculate finalization offset we should know index of nesting Loop auto loop_index = 0lu; auto loop = loop_end->input_value(i).get_node_shared_ptr(); @@ -93,7 +92,8 @@ ngraph::snippets::pass::ResetBufferState::ResetBufferState() { port_idx = source_output.get_index(); loop_index++; } - + const auto 
result_shape = buffer->get_allocation_shape(); + NGRAPH_CHECK(loop_index < result_shape.size(), "Buffer has invalid Loop index and allocation shape rank"); const auto work_amount = std::accumulate(result_shape.rbegin(), result_shape.rbegin() + loop_index + 1, size_t(1), std::multiplies()); finalization_offsets[i_size + i] = calculate_required_finalization_offsets(work_amount, *(result_shape.rbegin() + loop_index)); diff --git a/src/common/snippets/src/pass/softmax_decomposition.cpp b/src/common/snippets/src/pass/softmax_decomposition.cpp index 8c1c79a4b54..a0259a4061b 100644 --- a/src/common/snippets/src/pass/softmax_decomposition.cpp +++ b/src/common/snippets/src/pass/softmax_decomposition.cpp @@ -126,7 +126,7 @@ ngraph::snippets::pass::SoftmaxDecomposition::SoftmaxDecomposition(const size_t apply_increments_sum, finalization_offsets_sum); const auto horizon_sum = std::make_shared(sum); - const auto buffer_exp = std::make_shared(loop_sum_end->output(0), buffer_allocation_rank); + const auto buffer_exp = std::make_shared(loop_sum_end->output(0), buffer_allocation_rank); /* =========================================== */ diff --git a/src/common/snippets/src/pass/vector_to_scalar.cpp b/src/common/snippets/src/pass/vector_to_scalar.cpp index 512a0731062..4f98a49de4e 100644 --- a/src/common/snippets/src/pass/vector_to_scalar.cpp +++ b/src/common/snippets/src/pass/vector_to_scalar.cpp @@ -24,7 +24,7 @@ ngraph::snippets::pass::SetScalarCountForLoad::SetScalarCountForLoad() { if (!load) return false; - load->set_count(1lu); + load->set_input_count(1lu, 0); return true; }); } @@ -43,7 +43,7 @@ ngraph::snippets::pass::SetScalarCountForStore::SetScalarCountForStore() { if (!store) return false; - store->set_count(1lu); + store->set_output_count(1lu, 0); return true; }); } diff --git a/src/common/snippets/src/utils.cpp b/src/common/snippets/src/utils.cpp index 3018d99e95f..6587ff93fa6 100644 --- a/src/common/snippets/src/utils.cpp +++ b/src/common/snippets/src/utils.cpp @@ 
-115,6 +115,17 @@ ov::PartialShape get_port_planar_shape(const Output& out) { return get_reordered_planar_shape(tensor_shape, layout); } +void set_transpose_output_layout(const ov::Output& port, const std::shared_ptr& node) { + const auto& const_order = as_type_ptr(node->get_input_node_shared_ptr(1)); + OPENVINO_ASSERT(const_order != nullptr, "Transpose order must be Constant to set layout!"); + set_output_layout(port, const_order->cast_vector()); +} + +void set_output_layout(const ov::Output& port, const std::vector& layout) { + auto& rt_info = port.get_node_shared_ptr()->get_rt_info(); + rt_info["Layout"] = layout; +} + } // namespace utils } // namespace snippets } // namespace ngraph diff --git a/src/common/snippets/tests/include/lowering_utils.hpp b/src/common/snippets/tests/include/lowering_utils.hpp index b0b1bafb245..7dfa71a4b6a 100644 --- a/src/common/snippets/tests/include/lowering_utils.hpp +++ b/src/common/snippets/tests/include/lowering_utils.hpp @@ -36,6 +36,9 @@ class DummyGenerator : public ngraph::snippets::Generator { public: DummyGenerator() : ngraph::snippets::Generator(std::make_shared()) {} DummyGenerator(const std::shared_ptr& t) : ngraph::snippets::Generator(t) {} + +protected: + opRegType get_specific_op_reg_type(const std::shared_ptr& op) const override { return vec2vec; }; }; class LoweringTests : public TransformationTestsF { diff --git a/src/common/snippets/tests/src/pass/set_scalar_count_for_load_and_store.cpp b/src/common/snippets/tests/src/pass/set_scalar_count_for_load_and_store.cpp index 9bb45a81fce..50448be3a5c 100644 --- a/src/common/snippets/tests/src/pass/set_scalar_count_for_load_and_store.cpp +++ b/src/common/snippets/tests/src/pass/set_scalar_count_for_load_and_store.cpp @@ -19,18 +19,20 @@ using namespace ngraph; // todo: Rewrite this test using Snippets test infrastructure. 
See ./include/canonicalization.hpp for example -template -size_t get_count(const std::shared_ptr& f, const std::string& name) { - size_t load_count = std::numeric_limits::max(); +size_t get_count(const std::shared_ptr& f, const std::string& name, bool is_load = true) { + size_t count = std::numeric_limits::max(); for (auto op : f->get_ops()) { if (op->get_friendly_name() == name) { - load_count = ov::as_type_ptr(op)->get_count(); + if (const auto memory_access = std::dynamic_pointer_cast(op)) { + count = is_load ? memory_access->get_input_offset(0) + : memory_access->get_output_offset(0); + } } } - return load_count; + return count; } -TEST(TransformationTests, SetScalarCountForLoad) { +TEST(TransformationTests, SetScalarCountForLoadStore) { std::shared_ptr f(nullptr), f_ref(nullptr); const auto count = 16; { @@ -39,11 +41,13 @@ TEST(TransformationTests, SetScalarCountForLoad) { load->set_friendly_name("load"); auto neg = std::make_shared(load); auto store = std::make_shared(neg, count); + store->set_friendly_name("store"); f = std::make_shared(NodeVector{store}, ParameterVector{data}); pass::Manager m; m.register_pass(); m.register_pass(); + m.register_pass(); m.run_passes(f); ASSERT_NO_THROW(check_rt_info(f)); } @@ -52,39 +56,6 @@ TEST(TransformationTests, SetScalarCountForLoad) { auto load = std::make_shared(data, 1lu); load->set_friendly_name("load_ref"); auto neg = std::make_shared(load); - auto store = std::make_shared(neg, count); - f_ref = std::make_shared(NodeVector{store}, ParameterVector{data}); - } - - auto res = compare_functions(f, f_ref); - ASSERT_TRUE(res.first) << res.second; - - auto load_count = get_count(f, "load"); - auto load_count_ref = get_count(f_ref, "load_ref"); - ASSERT_EQ(load_count, load_count_ref); -} - -TEST(TransformationTests, SetScalarCountForStore) { - std::shared_ptr f(nullptr), f_ref(nullptr); - const auto count = 16; - { - auto data = std::make_shared(element::f32, Shape{2, 2}); - auto load = std::make_shared(data, count); - 
auto neg = std::make_shared(load); - auto store = std::make_shared(neg, count); - store->set_friendly_name("store"); - f = std::make_shared(NodeVector{store}, ParameterVector{data}); - - pass::Manager m; - m.register_pass(); - m.register_pass(); - m.run_passes(f); - ASSERT_NO_THROW(check_rt_info(f)); - } - { - auto data = std::make_shared(element::f32, Shape{2, 2}); - auto load = std::make_shared(data, count); - auto neg = std::make_shared(load); auto store = std::make_shared(neg, 1lu); store->set_friendly_name("store_ref"); f_ref = std::make_shared(NodeVector{store}, ParameterVector{data}); @@ -93,7 +64,11 @@ TEST(TransformationTests, SetScalarCountForStore) { auto res = compare_functions(f, f_ref); ASSERT_TRUE(res.first) << res.second; - int64_t store_count = get_count(f, "store"); - int64_t store_count_ref = get_count(f_ref, "store_ref"); + auto load_count = get_count(f, "load"); + auto load_count_ref = get_count(f_ref, "load_ref"); + ASSERT_EQ(load_count, load_count_ref); + + auto store_count = get_count(f, "store", false); + auto store_count_ref = get_count(f_ref, "store_ref", false); ASSERT_EQ(store_count, store_count_ref); } diff --git a/src/common/snippets/tests/src/registers.cpp b/src/common/snippets/tests/src/registers.cpp index 531190e6048..e9d7c503802 100644 --- a/src/common/snippets/tests/src/registers.cpp +++ b/src/common/snippets/tests/src/registers.cpp @@ -13,6 +13,7 @@ #include #include "common_test_utils/ngraph_test_utils.hpp" +#include "lowering_utils.hpp" using namespace testing; using namespace ngraph; @@ -20,6 +21,7 @@ using namespace ngraph; // todo: Rewrite this test using Snippets test infrastructure. 
See ./include/canonicalization.hpp for example TEST(TransformationTests, AssignRegisters) { + const auto generator = std::make_shared(); std::shared_ptr f(nullptr); { auto p0 = std::make_shared(element::f32, Shape(1)); @@ -37,7 +39,12 @@ TEST(TransformationTests, AssignRegisters) { pass::Manager m; m.register_pass(); - m.register_pass(); + std::function& op)> reg_type_mapper = + [=](const std::shared_ptr& op) -> snippets::Generator::opRegType { + return generator->get_op_reg_type(op); + }; + m.register_pass(reg_type_mapper); + m.run_passes(f); ASSERT_NO_THROW(check_rt_info(f)); } @@ -73,6 +80,7 @@ TEST(TransformationTests, AssignRegisters) { } TEST(TransformationTests, AssignRegisters2) { + const auto generator = std::make_shared(); std::shared_ptr f(nullptr); { auto p0 = std::make_shared(ngraph::element::f32, Shape()); @@ -126,7 +134,11 @@ TEST(TransformationTests, AssignRegisters2) { pass::Manager m; m.register_pass(); - m.register_pass(); + std::function& op)> reg_type_mapper = + [=](const std::shared_ptr& op) -> snippets::Generator::opRegType { + return generator->get_op_reg_type(op); + }; + m.register_pass(reg_type_mapper); m.run_passes(f); ASSERT_NO_THROW(check_rt_info(f)); } diff --git a/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp b/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp index 8c2e666d6b6..3841a768d42 100644 --- a/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp +++ b/src/plugins/intel_cpu/src/emitters/cpu_generator.cpp @@ -18,7 +18,8 @@ #include "snippets_transformations/op/load_convert.hpp" #include "snippets_transformations/op/store_convert.hpp" #include "snippets_transformations/op/fused_mul_add.hpp" -#include "snippets/op/brgemm.hpp" +#include "snippets_transformations/op/brgemm_copy_b.hpp" +#include "snippets_transformations/op/brgemm_cpu.hpp" #include "ngraph_transformations/op/swish_cpu.hpp" #include @@ -144,7 +145,8 @@ ov::intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_ 
jitters[ngraph::snippets::op::Kernel::get_type_info_static()] = CREATE_EMITTER(KernelEmitter); jitters[ngraph::snippets::op::LoopBegin::get_type_info_static()] = CREATE_EMITTER(LoopBeginEmitter); jitters[ngraph::snippets::op::LoopEnd::get_type_info_static()] = CREATE_EMITTER(LoopEndEmitter); - jitters[ngraph::snippets::op::Brgemm::get_type_info_static()] = CREATE_EMITTER(BrgemmEmitter); + jitters[ov::intel_cpu::BrgemmCPU::get_type_info_static()] = CREATE_EMITTER(BrgemmEmitter); + jitters[ov::intel_cpu::BrgemmCopyB::get_type_info_static()] = CREATE_EMITTER(BrgemmCopyBEmitter); } size_t ov::intel_cpu::CPUTargetMachine::get_lanes() const { @@ -169,3 +171,15 @@ code ov::intel_cpu::CPUTargetMachine::get_snippet() const { ov::intel_cpu::CPUGenerator::CPUGenerator(dnnl::impl::cpu::x64::cpu_isa_t isa_) : Generator(std::make_shared(isa_)) { } + +ngraph::snippets::Generator::opRegType ov::intel_cpu::CPUGenerator::get_specific_op_reg_type(const std::shared_ptr& op) const { + if (std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op)) + return gpr2gpr; + else if ( + std::dynamic_pointer_cast(op) || + std::dynamic_pointer_cast(op)) + return vec2vec; + else + throw ov::Exception("Register type of the operation " + std::string(op->get_type_name()) + " isn't determined!"); +} diff --git a/src/plugins/intel_cpu/src/emitters/cpu_generator.hpp b/src/plugins/intel_cpu/src/emitters/cpu_generator.hpp index 7301fcb177b..b624d2c0b09 100644 --- a/src/plugins/intel_cpu/src/emitters/cpu_generator.hpp +++ b/src/plugins/intel_cpu/src/emitters/cpu_generator.hpp @@ -28,6 +28,9 @@ private: class CPUGenerator : public ngraph::snippets::Generator { public: CPUGenerator(dnnl::impl::cpu::x64::cpu_isa_t isa); + +protected: + opRegType get_specific_op_reg_type(const std::shared_ptr& op) const override; }; } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp index 
4f63dd641f6..338cb62dcec 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp +++ b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.cpp @@ -6,9 +6,10 @@ #include #include "jit_snippets_emitters.hpp" -#include "snippets/op/brgemm.hpp" #include "snippets/op/subgraph.hpp" #include "snippets/utils.hpp" +#include "snippets_transformations/op/brgemm_copy_b.hpp" +#include "snippets_transformations/op/brgemm_cpu.hpp" using namespace InferenceEngine; using ngraph::snippets::op::Subgraph; @@ -20,6 +21,10 @@ using namespace dnnl::impl::cpu::x64; namespace ov { namespace intel_cpu { +namespace { +constexpr size_t gpr_size = 8; +} // namespace + inline static void transform_idxs_to_regs(const std::vector& idxs, std::vector& regs) { regs.resize(idxs.size()); std::transform(idxs.begin(), idxs.end(), regs.begin(), [](size_t idx){return Reg64(static_cast(idx));}); @@ -68,7 +73,8 @@ void jit_container_emitter::map_abstract_registers(mapping_info& gpr_map_pool, // where all utility emitters align with conventional Op emitters if (std::dynamic_pointer_cast(emitter) || std::dynamic_pointer_cast(emitter) || - std::dynamic_pointer_cast(emitter)) + std::dynamic_pointer_cast(emitter) || + std::dynamic_pointer_cast(emitter)) in_physical_regs = map_regs(in_abstract_regs, gpr_map_pool); else in_physical_regs = std::move(in_abstract_regs); @@ -182,7 +188,8 @@ KernelEmitter::KernelEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl: // todo: how this will be handled if Brgemm in & out are op::Buffer // Brgemm is a special case since it incorporates input and output (we use onednn kernel) // Just like Load & Store it requires offsets calculation - const auto is_brgemm = std::dynamic_pointer_cast(emitter) != nullptr; + const auto is_brgemm = std::dynamic_pointer_cast(emitter) || + std::dynamic_pointer_cast(emitter); return emitter_type == gpr_to_vec || emitter_type == vec_to_gpr || is_brgemm; }); // Note that we can't use reg_indexes_idx or 
reg_const_params_idx to store data pointers because these two @@ -567,9 +574,6 @@ LoadEmitter::LoadEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu IE_THROW() << "LoadEmitter supports only equal input and output types but gets: " << src_prc.name() << " and " << dst_prc.name(); const auto load = std::dynamic_pointer_cast(n); - if (!load) - IE_THROW() << "LoadEmitter expects Load snippets op"; - count = load->get_count(); byte_offset = load->get_offset(); in_out_type_ = emitter_in_out_map::gpr_to_vec; @@ -606,9 +610,6 @@ BroadcastLoadEmitter::BroadcastLoadEmitter(dnnl::impl::cpu::x64::jit_generator* IE_THROW() << "BroadcastEmitters support only equal input and output types but gets: " << src_prc.name() << " and " << dst_prc.name(); const auto broadcast_load = std::dynamic_pointer_cast(n); - if (!broadcast_load) - IE_THROW() << "BroadcastLoadEmitter expects BroadcastLoad snippets op"; - byte_offset = broadcast_load->get_offset(); in_out_type_ = emitter_in_out_map::gpr_to_vec; } @@ -717,12 +718,15 @@ size_t BrgemmEmitter::getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) const { return mIdx * 4 + kIdx * 2 + nIdx; } BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, - const std::shared_ptr& node) : jit_emitter(h, isa, node) { + const std::shared_ptr& node) : jit_emitter(h, isa, node) { in_out_type_ = emitter_in_out_map::gpr_to_gpr; - const auto& brgemm_node = as_type_ptr(node); + const auto& brgemm_node = as_type_ptr(node); if (brgemm_node->is_dynamic()) IE_THROW() << "Snippets don't support code generation for dynamic Brgemm"; - const OutputVector io_values {brgemm_node->input_value(0), brgemm_node->input_value(1), brgemm_node->output(0)}; + const auto brgemm_copy = brgemm_node->is_with_data_repacking() ? brgemm_node->get_brgemm_copy() : nullptr; + const OutputVector io_values {brgemm_node->input_value(0), + brgemm_copy ? 
brgemm_copy->input_value(0) : brgemm_node->input_value(1), + brgemm_node->output(0)}; std::vector leading_dimensions; std::vector> io_layouts; for (const auto& val : io_values) { @@ -747,51 +751,61 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl: io_layouts.push_back(layout); } } - // todo: leave AMX and VNNI related code for now, it'll help to enable int8 and bf16 support - bool isAMXSupported = mayiuse(avx512_core_amx); const auto& A_shape = io_values[0].get_shape(); const auto& A_layout = io_layouts[0]; const auto& C_shape = io_values[2].get_shape(); const auto& C_layout = io_layouts[2]; - M = C_shape[C_layout[2]]; - K = A_shape[A_layout[3]]; - M_blk = matmulOptimalM; - M_tail = M % M_blk; + // We need to find the original M, N, K using the layouts and ordered shapes + // Layout: 0, 1, 2, 3 => New layout: 0, 2, 1, 3 + // Shape: 1, 3, 5, 9 => New Shape: 1, 5, 3, 9 + // To find the original 2nd dimension, we should find the index of value `2` in the new layout + // and take the dimension from the new shape at this index + auto get_ordered_idx = [](const std::vector& layout, size_t idx) { + return std::distance(layout.begin(), std::find(layout.begin(), layout.end(), idx)); + }; + + m_M = C_shape[get_ordered_idx(C_layout, C_layout.size() - 2)]; + m_K = A_shape[get_ordered_idx(A_layout, A_layout.size() - 1)]; + m_M_blk = matmulOptimalM; + m_M_tail = m_M % m_M_blk; // B_shape[B_layout[3]] - N = C_shape[C_layout[3]]; + m_N = C_shape[get_ordered_idx(C_layout, C_layout.size() - 1)]; auto brg0Prc = InferenceEngine::details::convertPrecision(brgemm_node->get_input_element_type(0)); auto brg1Prc = InferenceEngine::details::convertPrecision(brgemm_node->get_input_element_type(1)); io_data_size = {brg0Prc.size(), brg1Prc.size(), brgemm_node->get_output_element_type(0).size()}; - brg0VnniFactor = 4 / 
brg0Prc.size(); + bool brgWithAMX = brgemm_node->is_amx(); - N_blk = brg0Prc == Precision::FP32 ? N : - brg0Prc == Precision::BF16 ? 32 : 64; - N_tail = N % N_blk; - K_blk = brg0WithAMX ? brg0Prc == Precision::BF16 ? 32 : 64 - : K; - K_tail = K % K_blk; + m_with_comp = brgemm_node->is_with_compensations(); + m_with_scratch = brgemm_node->is_with_scratchpad(); + + m_N_blk = brg1Prc == Precision::FP32 ? m_N : + brg1Prc == Precision::BF16 ? 32 : 64; + m_N_tail = m_N % m_N_blk; + m_K_blk = brgWithAMX ? brg0Prc == Precision::BF16 ? 32 : 64 + : m_K; + m_K_tail = m_K % m_K_blk; size_t brg0BaseIdx = -1; for (size_t m = 0; m < 2; m++) { for (size_t k = 0; k < 2; k++) { for (size_t n = 0; n < 2; n++) { - auto& brgemmCtx = brgCtxs0[getBrgIdx(m, k, n)]; + auto& brgemmCtx = m_brgCtxs0[getBrgIdx(m, k, n)]; - auto M_ = m ? M_tail - : M < M_blk ? 0 : M_blk; - auto N_ = n ? N_tail : N - N_tail; - auto K_ = k ? K_tail : K - K_tail; - auto beta = k && brgCtxs0[getBrgIdx(m, 0, n)].K != 0 ? 1.0f : 0.0f; + auto M_ = m ? m_M_tail + : m_M < m_M_blk ? 0 : m_M_blk; + auto N_ = n ? m_N_tail : m_N - m_N_tail; + auto K_ = k ? m_K_tail : m_K - m_K_tail; + auto beta = k && m_brgCtxs0[getBrgIdx(m, 0, n)].K != 0 ? 1.0f : 0.0f; brgemmCtx.M = M_; brgemmCtx.N = N_; brgemmCtx.K = K_; brgemmCtx.LDA = leading_dimensions[0]; - brgemmCtx.LDB = leading_dimensions[1]; + brgemmCtx.LDB = brgemm_node->is_with_data_repacking() ? 
rnd_up(m_N, m_N_blk) : leading_dimensions[1]; brgemmCtx.LDC = leading_dimensions[2]; brgemmCtx.dt_in0 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(brg0Prc)); brgemmCtx.dt_in1 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(brg1Prc)); @@ -801,22 +815,46 @@ BrgemmEmitter::BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl: if (M_ != 0 && K_ != 0 && N_ != 0) { if (brg0BaseIdx == -1) brg0BaseIdx = getBrgIdx(m, k, n); - initBrgemm(brgemmCtx, brgKernels0[getBrgIdx(m, k, n)], brg0WithAMX); + initBrgemm(brgemmCtx, m_brgKernels0[getBrgIdx(m, k, n)], brgWithAMX); } } } } - load_offset_a = brgemm_node->get_offset_a(); - load_offset_b = brgemm_node->get_offset_b(); - store_offset_c = brgemm_node->get_offset_c(); + m_load_offset_a = brgemm_node->get_offset_a(); + m_load_offset_b = brgemm_node->get_offset_b(); + m_store_offset_c = brgemm_node->get_offset_c(); + if (m_with_scratch) + m_load_offset_scratch = brgemm_node->get_offset_scratch(); +} + +std::set> BrgemmEmitter::get_supported_precisions(const std::shared_ptr& node) { + const auto brgemm = as_type_ptr(node); + OPENVINO_ASSERT(brgemm, "BrgemmEmitter::get_supported_precisions() expects BrgemmCPU node"); + switch (brgemm->get_type()) { + case BrgemmCPU::Type::Floating: + return {{element::f32, element::f32}}; + case BrgemmCPU::Type::WithDataRepacking: + return {{element::u8, element::i8}, + {element::bf16, element::bf16}}; + case BrgemmCPU::Type::WithCompensations: + return {{element::i8, element::i8, element::f32}}; + case BrgemmCPU::Type::AMX: + return {{element::i8, element::i8, element::u8}, + {element::u8, element::i8, element::u8}, + {element::bf16, element::bf16, element::u8}}; + default: + throw ov::Exception("BrgemmEmitter got BrgemmCPU node with unsupported type"); + } } void BrgemmEmitter::initBrgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel, bool use_amx) const { brgemm_t brgDesc; brgemm_strides_t strides {static_cast(ctx.M * ctx.K), static_cast(ctx.K * ctx.N)}; - // When 
implementing int8 support, note that isa logics is more complicated in the MHA node - auto status = brgemm_desc_init(&brgDesc, host_isa_, brgemm_strd, ctx.dt_in0, ctx.dt_in1, + const bool is_int8 = utils::one_of(ctx.dt_in0, data_type::u8, data_type::s8) && utils::one_of(ctx.dt_in1, data_type::u8, data_type::s8); + auto isa = use_amx ? isa_undef + : ctx.dt_in0 == dnnl_data_type_t::dnnl_bf16 ? avx512_core_bf16 : (is_int8 ? avx512_core_vnni : avx512_core); + auto status = brgemm_desc_init(&brgDesc, isa, brgemm_strd, ctx.dt_in0, ctx.dt_in1, false, false, brgemm_row_major, 1.f, ctx.beta, ctx.LDA, ctx.LDB, ctx.LDC, ctx.M, ctx.N, ctx.K, &strides); if (status != dnnl_success) IE_THROW() << "BrgemmEmitter cannot initialize brgemm descriptor due to invalid params"; @@ -837,23 +875,91 @@ void BrgemmEmitter::initBrgemm(brgemmCtx& ctx, std::unique_ptr& void BrgemmEmitter::emit_impl(const std::vector& in, const std::vector& out) const { - if (host_isa_ == cpu::x64::sse41 || host_isa_ == cpu::x64::avx2) { - IE_THROW() << "BrgemmEmitter requires at least avx512_core instruction set"; - } else if (host_isa_ == cpu::x64::avx512_core) { - emit_isa(in, out); + if (host_isa_ == cpu::x64::avx512_core) { + Xbyak::Reg64 input_0(static_cast(in[0])); + Xbyak::Reg64 input_1(static_cast(in[1])); + Xbyak::Reg64 input_2(static_cast(0)); // scratch. 
Default reg index is 0 if there isn't scratch + if (m_with_scratch) { + if (in.size() != 3) { + IE_THROW() << "BRGEMM Emitter expects 3 inputs if there are compensations/wsp"; + } + input_2 = Xbyak::Reg64(static_cast(in[2])); + } + Xbyak::Reg64 output_0(static_cast(out[0])); + + for (size_t mb = 0; mb < div_up(m_M, m_M_blk); mb++) { + const bool is_M_tail = (m_M - mb * m_M_blk < m_M_blk); + + size_t brgIdx0 = getBrgIdx(0, 0, 0); + size_t K0_step0 = m_brgCtxs0[brgIdx0].K; + size_t K0_step1 = m_brgCtxs0[brgIdx0].K * m_brgCtxs0[brgIdx0].LDB; + size_t N0_step0 = m_brgCtxs0[brgIdx0].N * m_brg0VnniFactor; + size_t N0_step1 = m_brgCtxs0[brgIdx0].N; + for (size_t n = 0; n < 2; n++) { + for (size_t k = 0; k < 2; k++) { + size_t mIdx = is_M_tail ? 1 : 0; + auto& brgemmCtx = m_brgCtxs0[getBrgIdx(mIdx, k, n)]; + + if (brgemmCtx.K != 0 && brgemmCtx.N != 0) { + const size_t in0_offset = m_load_offset_a + (k * K0_step0 + mb * m_M_blk * brgemmCtx.LDA) * io_data_size[0]; + const size_t in1_offset = m_load_offset_b + (k * K0_step1 + n * N0_step0) * io_data_size[1]; + const size_t in2_offset = m_load_offset_scratch + (m_with_comp ? 
n * N0_step1 * sizeof(int32_t) : 0); + const size_t out0_offset = m_store_offset_c + (n * N0_step1 + mb * m_M_blk * brgemmCtx.LDC) * io_data_size[2]; + + emit_brgemm_kernel_call(m_brgKernels0[getBrgIdx(mIdx, k, n)].get(), + brgemmCtx, + input_0, + input_1, + input_2, + output_0, + in0_offset, + in1_offset, + in2_offset, + out0_offset); + } + } + } + } } else { - assert(!"unsupported isa"); + IE_THROW() << "BrgemmEmitter requires at least avx512_core instruction set"; } } -template -void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, int bs, - Reg64 addr_A, Reg64 addr_B, - const brgemm_batch_element_t *batch, Reg64 addr_C, void *scratch, - const size_t in0_kernel_offset, const size_t in1_kernel_offset, const size_t out0_kernel_offset) const { - using Vmm = typename dnnl::impl::utils::conditional3::type; - size_t gpr_size = 8; - Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->rax, - h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx}; + +void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brg_kernel, const brgemmCtx& ctx, + Reg64 addr_A, Reg64 addr_B, Reg64 scratch, Reg64 addr_C, + const size_t in0_kernel_offset, const size_t in1_kernel_offset, + const size_t in2_kernel_offset, const size_t out0_kernel_offset) const { + if (ctx.is_with_amx) { + Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->rax, + h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx}; + size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]); + + h->sub(h->rsp, n_gprs_to_save * gpr_size); + for (size_t i = 0; i < n_gprs_to_save; ++i) + h->mov(h->ptr[h->rsp + i * gpr_size], gprs_to_save[i]); + + // save function address in gpr to pass in call instruction + const auto& overload = static_cast(amx_tile_configure); + h->mov(h->rbp, reinterpret_cast(overload)); + h->mov(abi_param1, reinterpret_cast(ctx.palette)); + + // align stack on 16-byte as ABI requires + // note that RBX must not be changed by the callee + 
h->mov(h->rbx, h->rsp); + h->and_(h->rbx, 0xf); + h->sub(h->rsp, h->rbx); + + h->call(h->rbp); + + h->add(h->rsp, h->rbx); + // restore gpr registers + for (int i = n_gprs_to_save - 1; i >= 0; --i) + h->mov(gprs_to_save[i], h->ptr[h->rsp + i * gpr_size]); + h->add(h->rsp, n_gprs_to_save * gpr_size); + } + + Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15, + h->rax, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx}; size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]); h->sub(h->rsp, n_gprs_to_save * gpr_size); @@ -862,14 +968,12 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, in // caller obligation to save k-regs as callee may use them size_t n_k_regs_to_save = 8; - if (isa == cpu::x64::avx512_core) { - h->sub(h->rsp, n_k_regs_to_save * k_mask_size); - for (size_t i = 0; i < n_k_regs_to_save; ++i) { - if (mayiuse(avx512_core)) - h->kmovq(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast(i))); - else - h->kmovw(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast(i))); - } + h->sub(h->rsp, n_k_regs_to_save * k_mask_size); + for (size_t i = 0; i < n_k_regs_to_save; ++i) { + if (mayiuse(avx512_core)) + h->kmovq(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast(i))); + else + h->kmovw(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast(i))); } // 1. Caller obligation to save vector registers as callee may use them. @@ -879,13 +983,16 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, in // `host_isa::vecs_count`. 
h->sub(h->rsp, get_max_vecs_count() * get_vec_length()); for (size_t i = 0; i < get_max_vecs_count(); ++i) - h->uni_vmovups(h->ptr[h->rsp + i * get_vec_length()], Vmm(i)); + h->uni_vmovups(h->ptr[h->rsp + i * get_vec_length()], Zmm(i)); + size_t num_args_passed_on_stack = 0; // save function address in gpr to pass in call instruction const auto& brgemm_kernel_overload = static_cast(kernel_execute); + void*, + void*, + int)>(kernel_execute); h->mov(h->rbp, reinterpret_cast(brgemm_kernel_overload)); // todo: several of addr_{A, B, C} could be also abi_paramX, so one of them could be corrupted // if moving directly h->uni_vmovq(abi_paramX, adr_X). Save them to vector regs to avoid corruption. @@ -893,16 +1000,44 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, in h->uni_vmovq(Xmm(0), addr_A); h->uni_vmovq(Xmm(1), addr_B); h->uni_vmovq(Xmm(2), addr_C); - + if (m_with_scratch) + h->uni_vmovq(Xmm(3), scratch); + // todo: Windows ABI : requires different num of arguments passed in regs and on the stack. Need to align. 
const auto data_ptr_reg = [&](Xmm xmm, Xbyak::Reg64 reg, size_t bytes_offset) { h->uni_vmovq(reg, xmm); if (bytes_offset) h->add(reg, bytes_offset); }; - h->mov(abi_param1, reinterpret_cast(brgKernel)); + h->mov(abi_param1, reinterpret_cast(brg_kernel)); data_ptr_reg(Xmm(0), abi_param2, in0_kernel_offset); data_ptr_reg(Xmm(1), abi_param3, in1_kernel_offset); data_ptr_reg(Xmm(2), abi_param4, out0_kernel_offset); +#ifdef _WIN32 + // Before function call we should allocate stack area for + // - register parameters - ABI parameters (shadow space) + // - stack parameters - remaining parameters + num_args_passed_on_stack = 6; // count of function brgemm_kernel_overload() parameters + size_t abi_param_count = sizeof(abi_param_regs) / sizeof(abi_param_regs[0]); + h->sub(h->rsp, num_args_passed_on_stack * gpr_size); + + // Push the remaining parameters on the stack + if (m_with_scratch) { + h->uni_vmovq(h->qword[h->rsp + (abi_param_count + 0) * gpr_size], Xmm(3)); + if (in2_kernel_offset) h->add(h->qword[h->rsp + (abi_param_count + 0) * gpr_size], in2_kernel_offset); + } else { + h->mov(h->qword[h->rsp + (abi_param_count + 0) * gpr_size], reinterpret_cast(nullptr)); + } + h->mov(abi_not_param1, static_cast(m_with_comp)); + h->mov(h->qword[h->rsp + (abi_param_count + 1) * gpr_size], abi_not_param1); +#else + if (m_with_scratch) { + data_ptr_reg(Xmm(3), abi_param5, in2_kernel_offset); + } else { + h->mov(abi_param5, reinterpret_cast(nullptr)); + } + h->mov(abi_param6, static_cast(m_with_comp)); +#endif + // align stack on 16-byte as ABI requires // note that RBX must not be changed by the callee h->mov(h->rbx, h->rsp); @@ -912,22 +1047,22 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, in h->call(h->rbp); h->add(h->rsp, h->rbx); + if (num_args_passed_on_stack > 0) + h->add(h->rsp, num_args_passed_on_stack * gpr_size); // restore vector registers for (int i = static_cast(get_max_vecs_count()) - 1; i >= 0; --i) { - h->uni_vmovups(Vmm(i), 
h->ptr[h->rsp + i * get_vec_length()]); + h->uni_vmovups(Zmm(i), h->ptr[h->rsp + i * get_vec_length()]); } h->add(h->rsp, (get_max_vecs_count()) * get_vec_length()); // restore k registers - if (isa == cpu::x64::avx512_core) { - for (int i = n_k_regs_to_save - 1; i >= 0; --i) { - if (mayiuse(avx512_core)) - h->kmovq(Opmask(i), h->ptr[h->rsp + i * k_mask_size]); - else - h->kmovw(Opmask(i), h->ptr[h->rsp + i * k_mask_size]); - } - h->add(h->rsp, n_k_regs_to_save * k_mask_size); + for (int i = n_k_regs_to_save - 1; i >= 0; --i) { + if (mayiuse(avx512_core)) + h->kmovq(Opmask(i), h->ptr[h->rsp + i * k_mask_size]); + else + h->kmovw(Opmask(i), h->ptr[h->rsp + i * k_mask_size]); } + h->add(h->rsp, n_k_regs_to_save * k_mask_size); // restore gpr registers for (int i = n_gprs_to_save - 1; i >= 0; --i) @@ -935,9 +1070,8 @@ void BrgemmEmitter::emit_brgemm_kernel_call(const brgemm_kernel_t *brgKernel, in h->add(h->rsp, n_gprs_to_save * gpr_size); } -void BrgemmEmitter::kernel_execute(const brgemm_kernel_t *brg_kernel, const void *A, const void *B, void *C) { - // TODO: There are 4 available abi_params on Windows so we have the copy of brgemm_kernel_execute() function - // with 4 runtime parameters (kernel and I/O) and 4 default parameter values (batch, bs and scratch) +void BrgemmEmitter::kernel_execute(const brgemm_kernel_t *brg_kernel, + const void *A, const void *B, void *C, void *scratch, int with_comp) { brgemm_kernel_params_t brgemm_p; brgemm_p.batch = nullptr; // default value @@ -945,54 +1079,266 @@ void BrgemmEmitter::kernel_execute(const brgemm_kernel_t *brg_kernel, const void brgemm_p.ptr_B = B; brgemm_p.ptr_C = C; brgemm_p.ptr_D = C; - brgemm_p.ptr_buf = nullptr; // default value + brgemm_p.ptr_buf = scratch; brgemm_p.ptr_bias = nullptr; - brgemm_p.do_post_ops = 0; - brgemm_p.do_apply_comp = 0; + brgemm_p.do_post_ops = static_cast(with_comp); + brgemm_p.do_apply_comp = static_cast(with_comp); brgemm_p.skip_accm = 0; brgemm_p.BS = 1; // default value 
assert(brg_kernel); (*brg_kernel)(&brgemm_p); } -template -void BrgemmEmitter::emit_isa(const std::vector &in, const std::vector &out) const { - Reg64 input_0(static_cast(in[0])); - Reg64 input_1(static_cast(in[1])); - Reg64 output_0(static_cast(out[0])); +BrgemmCopyBEmitter::BrgemmCopyBEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n) + : jit_emitter(h, isa, n) { + in_out_type_ = emitter_in_out_map::gpr_to_gpr; + const auto brgemm_repack = ov::as_type_ptr(n); + if (!brgemm_repack) + IE_THROW() << "BrgemmCopyBEmitters expects BrgemmCopyB node"; - for (size_t mb = 0; mb < div_up(M, M_blk); mb++) { - const bool is_M_tail = (M - mb * M_blk < M_blk); + m_brgemm_prc_in0 = brgemm_repack->get_src_element_type(); + m_brgemm_prc_in1 = brgemm_repack->get_input_element_type(0); + m_brgemmVNNIFactor = 4 / m_brgemm_prc_in0.size(); + m_with_comp = brgemm_repack->is_with_compensations(); + m_in_offset = brgemm_repack->get_offset_in(); + m_out_offset = brgemm_repack->get_offset_out(); + if (m_with_comp) + m_comp_offset = brgemm_repack->get_offset_compensations(); - size_t brgIdx0 = getBrgIdx(0, 0, 0); - size_t K0_step0 = brgCtxs0[brgIdx0].K; - size_t K0_step1 = brgCtxs0[brgIdx0].K * brgCtxs0[brgIdx0].LDB; - size_t N0_step0 = brgCtxs0[brgIdx0].N * brg0VnniFactor; - size_t N0_step1 = brgCtxs0[brgIdx0].N; - for (size_t n = 0; n < 2; n++) { - for (size_t k = 0; k < 2; k++) { - size_t mIdx = is_M_tail ? 
1 : 0; - auto& brgemmCtx = brgCtxs0[getBrgIdx(mIdx, k, n)]; - - if (brgemmCtx.K != 0 && brgemmCtx.N != 0) { - const size_t in0_offset = load_offset_a + (k * K0_step0 + mb * M_blk * brgemmCtx.LDA) * io_data_size[0]; - const size_t in1_offset = load_offset_b + (k * K0_step1 + n * N0_step0) * io_data_size[1]; - const size_t out0_offset = store_offset_c + (n * N0_step1 + mb * M_blk * brgemmCtx.LDC) * io_data_size[2]; - - emit_brgemm_kernel_call(brgKernels0[getBrgIdx(mIdx, k, n)].get(), - 1, - input_0, - input_1, - nullptr, - output_0, - nullptr, - in0_offset, - in1_offset, - out0_offset); - } - } + auto layout = ngraph::snippets::utils::get_node_output_layout(brgemm_repack->get_input_node_shared_ptr(0)); + const auto& original_shape = brgemm_repack->get_input_shape(0); + auto transposed_shape = original_shape; + size_t leading_dimension = *(original_shape.rbegin()); + if (!layout.empty()) { + transposed_shape.resize(layout.size(), 1); + for (size_t i = 0; i < layout.size(); ++i) { + transposed_shape[i] = original_shape[layout[i]]; } + // The idea here is to find "2" (for 4D shapes) in the layout and multiply dimensions that are to the right + // This implies that "3" is the last layout value, otherwise this layout is not supported. 
+ // counting from the end since shape could be prepended with ones + const int64_t num_last_dims = layout.end() - std::find(layout.begin(), layout.end(), layout.size() - 2) - 1; + if (layout.back() != layout.size() - 1 || num_last_dims < 1) + IE_THROW() << "BrgemmRepackEmitter detected invalid layout values: " << + "check that this shape + layout combination is schedulable"; + leading_dimension = std::accumulate(original_shape.end() - num_last_dims, original_shape.end(), 1, std::multiplies()); + } + + m_N = *(transposed_shape.rbegin()); + m_K = *(transposed_shape.rbegin() + 1); + + const bool isAMXSupported = mayiuse(avx512_core_amx); + const auto use_amx = isAMXSupported && m_brgemm_prc_in0 != ov::element::f32 && (m_K % m_brgemmVNNIFactor == 0) && (m_N % m_brgemmVNNIFactor == 0); + + m_N_blk = m_brgemm_prc_in1 == ov::element::f32 ? m_N : + m_brgemm_prc_in1 == ov::element::bf16 ? 32 : 64; + m_K_blk = use_amx ? m_brgemm_prc_in0 == ov::element::bf16 ? 32 : 64 + : m_K; + m_N_tail = m_N % m_N_blk; + m_K_tail = m_K % m_K_blk; + m_LDB = m_brgemm_prc_in1 == ov::element::f32 ? leading_dimension : rnd_up(m_N, m_N_blk); + + const auto dt_in0 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(InferenceEngine::details::convertPrecision(m_brgemm_prc_in0))); + const auto dt_in1 = static_cast(DnnlExtensionUtils::IEPrecisionToDataType(InferenceEngine::details::convertPrecision(m_brgemm_prc_in1))); + init_brgemm_copy(m_kernel, leading_dimension, m_N_blk, m_N_tail, m_LDB, m_K - m_K_tail, use_amx, dt_in0, dt_in1); +} + +void BrgemmCopyBEmitter::init_brgemm_copy(std::unique_ptr& kernel, + size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K, + bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1) const { + matmul::brgemm_matmul_conf_t brgCopyKernelConf; + brgCopyKernelConf.src_dt = dt_in0; + brgCopyKernelConf.wei_dt = dt_in1; + brgCopyKernelConf.wei_n_blk = static_cast(N_blk); + brgCopyKernelConf.wei_tag = dnnl_abcd; // What about other ranks? 
+ brgCopyKernelConf.copy_B_wei_stride = 0; + brgCopyKernelConf.LDB = static_cast(LDB); + brgCopyKernelConf.N = static_cast(N); + brgCopyKernelConf.N_tail = static_cast(N_tail); + brgCopyKernelConf.N_blk = static_cast(N_blk); + brgCopyKernelConf.K = static_cast(K); + brgCopyKernelConf.K_blk = static_cast(K); + brgCopyKernelConf.N_chunk_elems = brgCopyKernelConf.N_blk; + brgCopyKernelConf.b_dt_sz = DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.src_dt)); + brgCopyKernelConf.tr_b_dt_sz = DnnlExtensionUtils::sizeOfDataType(static_cast(brgCopyKernelConf.src_dt)); + brgCopyKernelConf.req_wei_vnni_downconvert = false; + + if (is_with_amx) { + brgCopyKernelConf.isa = avx512_core_amx; + brgCopyKernelConf.s8s8_compensation_required = false; + } else { + brgCopyKernelConf.isa = dt_in0 == dnnl_data_type_t::dnnl_bf16 ? avx512_core_bf16 : avx512_core_vnni; + brgCopyKernelConf.s8s8_compensation_required = dt_in0 == dnnl_data_type_t::dnnl_s8; + } + + brgCopyKernelConf.has_zero_point_a = false; + brgCopyKernelConf.has_zero_point_b = false; + brgCopyKernelConf.src_zp_type = dnnl::impl::cpu::x64::none; + + auto status = matmul::create_brgemm_matmul_copy_b(kernel, &brgCopyKernelConf); + if (status != dnnl_success) + IE_THROW() << "BrgemmRepackEmitter cannot create kernel due to invalid params"; +} + +void BrgemmCopyBEmitter::emit_impl(const std::vector& in, + const std::vector& out) const { + if (host_isa_ == cpu::x64::avx512_core) { + Xbyak::Reg64 src(static_cast(in[0])); + Xbyak::Reg64 dst(static_cast(out[0])); + Xbyak::Reg64 comp(static_cast(0)); // Compensations. 
Default reg idx is 0 if there aren't the compensations + if (m_with_comp) { + if (out.size() != 2) { + IE_THROW() << "BrgemmCopyBEmitter with compensations requires separate register for them"; + } + comp = Xbyak::Reg64(static_cast(out[1])); + } + + const size_t data_size = m_brgemm_prc_in1.size(); + for (size_t nb = 0; nb < div_up(m_N, m_N_blk); nb++) { + const size_t offset_in = m_in_offset + nb * m_N_blk * data_size; + const size_t offset_out = m_out_offset + nb * m_N_blk * m_brgemmVNNIFactor * data_size; + const size_t offset_comp = m_with_comp ? m_comp_offset + nb * m_N_blk * sizeof(int32_t) : 0; + + const bool is_N_tail = (m_N - nb * m_N_blk < m_N_blk); + const auto current_N_blk = is_N_tail ? m_N_tail : m_N_blk; + + emit_kernel_call(m_kernel.get(), src, dst, comp, current_N_blk, m_K, offset_in, offset_out, offset_comp); + } + } else { + IE_THROW() << "BrgemmCopyBEmitter requires at least avx512_core instruction set"; + } +} + +void BrgemmCopyBEmitter::emit_kernel_call(const matmul::jit_brgemm_matmul_copy_b_t* kernel, Reg64 src, Reg64 dst, Reg64 comp, + size_t N, size_t K, size_t offset_in, size_t offset_out, size_t offset_comp) const { + Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15, + h->rax, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp, h->rbx}; + size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]); + + h->sub(h->rsp, n_gprs_to_save * gpr_size); + for (size_t i = 0; i < n_gprs_to_save; ++i) + h->mov(h->ptr[h->rsp + i * gpr_size], gprs_to_save[i]); + + // caller obligation to save k-regs as callee may use them + size_t n_k_regs_to_save = 8; + h->sub(h->rsp, n_k_regs_to_save * k_mask_size); + for (size_t i = 0; i < n_k_regs_to_save; ++i) { + if (mayiuse(avx512_core)) + h->kmovq(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast(i))); + else + h->kmovw(h->ptr[h->rsp + i * k_mask_size], Opmask(static_cast(i))); + } + + // 1. Caller obligation to save vector registers as callee may use them. + // 2. 
There is an implicit assumption that the host code uses the same + // `isa` as the injector. Once the assumption is wrong, `vecs_count` and + // `vlen` should be replaced with `host_isa::vlen` and + // `host_isa::vecs_count`. + h->sub(h->rsp, get_max_vecs_count() * get_vec_length()); + for (size_t i = 0; i < get_max_vecs_count(); ++i) + h->uni_vmovups(h->ptr[h->rsp + i * get_vec_length()], Zmm(i)); + + const auto data_ptr = [&](Xmm xmm, Xbyak::Reg64 reg, size_t bytes_offset) { + h->uni_vmovq(reg, xmm); + if (bytes_offset) h->add(reg, bytes_offset); + }; +#ifdef _WIN32 + const auto push_value = [&](size_t value, size_t index) { + // Firstly we need to move integer to GPR. Then we can move value from GPR to stack + h->mov(abi_not_param1, value); + h->mov(h->qword[h->rsp + index * gpr_size], abi_not_param1); + }; +#endif + + size_t num_args_passed_on_stack = 0; + // save function address in gpr to pass in call instruction + const auto &kernel_overload = static_cast(execute); + h->mov(h->rbp, reinterpret_cast(kernel_overload)); + // todo: several of addr_{A, B, C} could be also abi_paramX, so one of them could be corrupted + // if moving directly h->uni_vmovq(abi_paramX, adr_X). Save them to vector regs to avoid corruption. + // It's likely that a more efficient solution exists. + h->uni_vmovq(Xmm(0), src); + h->uni_vmovq(Xmm(1), dst); + if (m_with_comp) + h->uni_vmovq(Xmm(2), comp); + // todo: Windows ABI : requires different num of arguments passed in regs and on the stack. Need to align. 
+ h->mov(abi_param1, reinterpret_cast(kernel)); + + data_ptr(Xmm(0), abi_param2, offset_in); + data_ptr(Xmm(1), abi_param3, offset_out); + if (m_with_comp) { + data_ptr(Xmm(2), abi_param4, offset_comp); + } else { + h->mov(abi_param4, reinterpret_cast(nullptr)); + } + +#ifdef _WIN32 + // Before function call we should allocate stack area for + // - register parameters - ABI parameters (shadow space) + // - stack parameters - remaining parameters + num_args_passed_on_stack = 6; // count of function kernel_overload() parameters + size_t abi_param_count = sizeof(abi_param_regs) / sizeof(abi_param_regs[0]); + + h->sub(h->rsp, num_args_passed_on_stack * gpr_size); + push_value(N, abi_param_count + 0); + push_value(K, abi_param_count + 1); +#else + h->mov(abi_param5, N); + h->mov(abi_param6, K); +#endif + // align stack on 16-byte as ABI requires + // note that RBX must not be changed by the callee + h->mov(h->rbx, h->rsp); + h->and_(h->rbx, 0xf); + h->sub(h->rsp, h->rbx); + + h->call(h->rbp); + + h->add(h->rsp, h->rbx); + if (num_args_passed_on_stack > 0) + h->add(h->rsp, gpr_size * num_args_passed_on_stack); + // restore vector registers + for (int i = static_cast(get_max_vecs_count()) - 1; i >= 0; --i) { + h->uni_vmovups(Zmm(i), h->ptr[h->rsp + i * get_vec_length()]); + } + h->add(h->rsp, (get_max_vecs_count()) * get_vec_length()); + + // restore k registers + for (int i = n_k_regs_to_save - 1; i >= 0; --i) { + if (mayiuse(avx512_core)) + h->kmovq(Opmask(i), h->ptr[h->rsp + i * k_mask_size]); + else + h->kmovw(Opmask(i), h->ptr[h->rsp + i * k_mask_size]); + } + h->add(h->rsp, n_k_regs_to_save * k_mask_size); + + // restore gpr registers + for (int i = n_gprs_to_save - 1; i >= 0; --i) + h->mov(gprs_to_save[i], h->ptr[h->rsp + i * gpr_size]); + h->add(h->rsp, n_gprs_to_save * gpr_size); +} + +void BrgemmCopyBEmitter::execute(matmul::jit_brgemm_matmul_copy_b_t *kernel, const void *src, + const void *dst, const void *comp, size_t N, size_t K) { + if (!kernel) + IE_THROW() 
<< "Kernel for `brgemm_copy_b` hasn't been created"; + + auto ctx = dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t::ctx_t(); + ctx.current_N_blk = N; + ctx.src = src; + ctx.tr_src = dst; + ctx.compensation_ptr = comp; + ctx.zp_a_compensation_ptr = nullptr; + ctx.zp_a_neg_value_ptr = nullptr; + ctx.current_K_start = 0; + ctx.current_K_iters = K; + + (*kernel)(&ctx); } HorizonMaxEmitter::HorizonMaxEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n) : diff --git a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp index cae08b3fe43..0f00eb6f704 100644 --- a/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp +++ b/src/plugins/intel_cpu/src/emitters/jit_snippets_emitters.hpp @@ -321,17 +321,13 @@ class BrgemmEmitter : public jit_emitter { public: BrgemmEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n); - size_t get_inputs_num() const override {return 2;} - static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr) { - return {{element::f32, element::f32}}; - } + size_t get_inputs_num() const override { return m_with_scratch ? 
3 : 2; } + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr); private: void emit_impl(const std::vector& in, const std::vector& out) const override; - template - void emit_isa(const std::vector &in, const std::vector &out) const; std::vector io_data_size {}; struct brgemmCtx { size_t M, N, K, LDA, LDB, LDC; @@ -342,29 +338,68 @@ private: float beta; }; void initBrgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel, bool use_amx) const; - template - void callBrgemm(brgemmCtx& ctx, std::unique_ptr& brgKernel, - const void* pin0, const void* pin1, void* pout, void* wsp) const; size_t getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) const; - template - void emit_brgemm_kernel_call(const dnnl::impl::cpu::x64::brgemm_kernel_t *brg_kernel, int bs, - Xbyak::Reg64 addr_A, Xbyak::Reg64 addr_B, - const dnnl::impl::cpu::x64::brgemm_batch_element_t *batch, Xbyak::Reg64 addr_C, void *scratch, - const size_t in0_kernel_offset, const size_t in1_kernel_offset, const size_t out0_kernel_offset) const; - static void kernel_execute(const dnnl::impl::cpu::x64::brgemm_kernel_t *brg_kernel, const void *A, const void *B, void *C); + + void emit_brgemm_kernel_call(const dnnl::impl::cpu::x64::brgemm_kernel_t* brg_kernel, const brgemmCtx& ctx, + Xbyak::Reg64 addr_A, Xbyak::Reg64 addr_B, Xbyak::Reg64 scratch, Xbyak::Reg64 addr_C, + const size_t in0_kernel_offset, const size_t in1_kernel_offset, + const size_t in2_kernel_offset, const size_t out0_kernel_offset) const; + static void kernel_execute(const dnnl::impl::cpu::x64::brgemm_kernel_t *brg_kernel, const void *A, const void *B, void *C, void *scratch, int with_comp); + static constexpr size_t BRGEMM_KERNELS_NUM = 8; static constexpr size_t matmulOptimalM = 32; - brgemmCtx brgCtxs0[BRGEMM_KERNELS_NUM]; - std::unique_ptr brgKernels0[BRGEMM_KERNELS_NUM]; + brgemmCtx m_brgCtxs0[BRGEMM_KERNELS_NUM]; + std::unique_ptr m_brgKernels0[BRGEMM_KERNELS_NUM]; - size_t M, M_blk, M_tail; - size_t K, K_blk, K_tail; - size_t N, 
N_blk, N_tail; - size_t brg0VnniFactor; + size_t m_M, m_M_blk, m_M_tail; + size_t m_K, m_K_blk, m_K_tail; + size_t m_N, m_N_blk, m_N_tail; + size_t m_brg0VnniFactor; - size_t load_offset_a = 0lu; - size_t load_offset_b = 0lu; - size_t store_offset_c = 0lu; + bool m_with_scratch = false; + bool m_with_comp = false; + + size_t m_load_offset_a = 0lu; + size_t m_load_offset_b = 0lu; + size_t m_load_offset_scratch = 0lu; + size_t m_store_offset_c = 0lu; +}; + +class BrgemmCopyBEmitter : public jit_emitter { +public: + BrgemmCopyBEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa, const std::shared_ptr& n); + + size_t get_inputs_num() const override {return 1;} + static std::set> get_supported_precisions(const std::shared_ptr& node = nullptr) { + return {{element::i8}, {element::bf16}}; + } + +private: + void emit_impl(const std::vector& in, + const std::vector& out) const override; + + void init_brgemm_copy(std::unique_ptr& kernel, + size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K, + bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1) const; + void emit_kernel_call(const dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t* kernel, + Xbyak::Reg64 src, Xbyak::Reg64 dst, Xbyak::Reg64 comp, size_t N, size_t K, + size_t offset_in, size_t offset_out, size_t offset_comp) const; + + static void execute(dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t* kernel, + const void* src, const void* dst, const void* comp, size_t N, size_t K); + + std::unique_ptr m_kernel; + + ov::element::Type m_brgemm_prc_in0, m_brgemm_prc_in1; + size_t m_N, m_N_blk, m_N_tail; + size_t m_K, m_K_blk, m_K_tail; + size_t m_LDB; + size_t m_brgemmVNNIFactor; + bool m_with_comp = false; + + size_t m_in_offset = 0lu; + size_t m_out_offset = 0lu; + size_t m_comp_offset = 0lu; }; class HorizonMaxEmitter : public jit_emitter { diff --git a/src/plugins/intel_cpu/src/extension.cpp b/src/plugins/intel_cpu/src/extension.cpp index 
fd0c07ea55b..a9d6e08377c 100644 --- a/src/plugins/intel_cpu/src/extension.cpp +++ b/src/plugins/intel_cpu/src/extension.cpp @@ -11,6 +11,8 @@ #include "ngraph_transformations/op/mha.hpp" #include "snippets_transformations/op/load_convert.hpp" #include "snippets_transformations/op/store_convert.hpp" +#include "snippets_transformations/op/brgemm_cpu.hpp" +#include "snippets_transformations/op/brgemm_copy_b.hpp" #include #include @@ -54,6 +56,8 @@ std::map Extension::getOpSets() { NGRAPH_OP(LoadConvertTruncation, ov::intel_cpu) NGRAPH_OP(StoreConvertSaturation, ov::intel_cpu) NGRAPH_OP(StoreConvertTruncation, ov::intel_cpu) + NGRAPH_OP(BrgemmCPU, ov::intel_cpu) + NGRAPH_OP(BrgemmCopyB, ov::intel_cpu) #undef NGRAPH_OP return opset; @@ -132,9 +136,9 @@ std::map Extension::getOpSets() { #define NGRAPH_OP(NAME, NAMESPACE) opset.insert(); NGRAPH_OP(Brgemm, ngraph::snippets::op) + NGRAPH_OP(Buffer, ngraph::snippets::op) NGRAPH_OP(BroadcastLoad, ngraph::snippets::op) NGRAPH_OP(BroadcastMove, ngraph::snippets::op) - NGRAPH_OP(Buffer, ngraph::snippets::op) NGRAPH_OP(ConvertSaturation, ngraph::snippets::op) NGRAPH_OP(ConvertTruncation, ngraph::snippets::op) NGRAPH_OP(Fill, ngraph::snippets::op) diff --git a/src/plugins/intel_cpu/src/nodes/subgraph.cpp b/src/plugins/intel_cpu/src/nodes/subgraph.cpp index 8eb425e7ec4..41b79e9e941 100644 --- a/src/plugins/intel_cpu/src/nodes/subgraph.cpp +++ b/src/plugins/intel_cpu/src/nodes/subgraph.cpp @@ -25,6 +25,7 @@ #include "utils/cpu_utils.hpp" #include "snippets_transformations/fuse_load_store_and_convert.hpp" #include "snippets_transformations/mul_add_to_fma.hpp" +#include "snippets_transformations/brgemm_to_brgemm_cpu.hpp" #include "snippets_transformations/remove_converts.hpp" #include "ngraph_transformations/convert_to_swish_cpu.hpp" @@ -536,6 +537,7 @@ void Snippet::generate(const jit_snippets_compile_args* jcp) { pre_dialect.register_pass(); ov::pass::Manager post_dialect; + post_dialect.register_pass(); ov::pass::Manager 
post_precision; post_precision.register_pass(); diff --git a/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.cpp b/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.cpp new file mode 100644 index 00000000000..63779e5848b --- /dev/null +++ b/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.cpp @@ -0,0 +1,96 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/itt.hpp" + +#include "brgemm_to_brgemm_cpu.hpp" +#include "snippets/snippets_isa.hpp" +#include "snippets/utils.hpp" +#include "op/brgemm_copy_b.hpp" +#include "op/brgemm_cpu.hpp" + +#include "ngraph/rt_info.hpp" +#include "ngraph/pattern/op/wrap_type.hpp" + +#include + +#include "cpu_shape.h" +#include "utils/general_utils.h" + + +namespace ov { +namespace intel_cpu { + +pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() { + MATCHER_SCOPE(BrgemmToBrgemmCPU); + + auto m_brgemm = ngraph::pattern::wrap_type(); + + auto callback = [=](ngraph::pattern::Matcher& m) { + OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::BrgemmToBrgemmCPU") + const auto node = m.get_match_root(); + const auto brgemm = ov::as_type_ptr(node); + const auto brgemm_plugin = ov::as_type_ptr(node); + if (!brgemm || brgemm_plugin) + throw ov::Exception("BrgemmCPU cannot be in body before BrgemmToBrgemmCPU pass"); + + if (brgemm->is_dynamic()) { + return false; + } + + const auto dimsMatMulIn0 = ngraph::snippets::utils::get_port_planar_shape(brgemm->input_value(0)).get_shape(); + const auto dimsMatMulIn1 = ngraph::snippets::utils::get_port_planar_shape(brgemm->input_value(1)).get_shape(); + + const auto K = *dimsMatMulIn0.rbegin(); + const auto N = *dimsMatMulIn1.rbegin(); + + const auto element_type_a = brgemm->get_input_element_type(0); + const auto brgemmVNNIFactor = 4 / element_type_a.size(); + const bool isAMXSupported = 
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx); + const bool with_amx = isAMXSupported && element_type_a != ov::element::f32 && (K % brgemmVNNIFactor == 0) && (N % brgemmVNNIFactor == 0); + const bool with_comp = element_type_a == ov::element::i8 && !with_amx; + + const auto offset_a = brgemm->get_offset_a(); + const auto offset_b = brgemm->get_offset_b(); + const auto offset_c = brgemm->get_offset_c(); + + std::shared_ptr brgemm_cpu = nullptr; + if (element_type_a == ov::element::f32) { + brgemm_cpu = std::make_shared(brgemm->input_value(0), brgemm->input_value(1), BrgemmCPU::Type::Floating, + offset_a, offset_b, offset_c); + } else { + const auto layoutIn1 = ngraph::snippets::utils::get_node_output_layout(brgemm->input_value(1).get_node_shared_ptr()); + const auto copy_b_type = with_comp ? BrgemmCopyB::WithCompensations : BrgemmCopyB::OnlyRepacking; + const auto brgemmRepackIn1 = std::make_shared(brgemm->input_value(1), element_type_a, copy_b_type, offset_b); + const auto buffer = std::make_shared(brgemmRepackIn1->output(0)); + + if (with_amx) { + const auto scratch = std::make_shared(ov::Shape{BrgemmCPU::SCRATCH_BYTE_SIZE}); + brgemm_cpu = std::make_shared(brgemm->input_value(0), buffer, scratch, BrgemmCPU::Type::AMX, + offset_a, offset_b, offset_c); + } else if (with_comp) { + const auto scratch = std::make_shared(brgemmRepackIn1->output(1)); + brgemm_cpu = std::make_shared(brgemm->input_value(0), buffer, scratch, BrgemmCPU::Type::WithCompensations, + offset_a, offset_b, offset_c); + } else if (one_of(element_type_a, ov::element::u8, ov::element::bf16)) { + brgemm_cpu = std::make_shared(brgemm->input_value(0), buffer, BrgemmCPU::Type::WithDataRepacking, + offset_a, offset_b, offset_c); + } else { + IE_THROW() << "Invalid configuration for BRGEMM CPU"; + } + } + + brgemm_cpu->set_friendly_name(brgemm->get_friendly_name()); + ngraph::snippets::utils::set_output_layout(brgemm_cpu->output(0), 
ngraph::snippets::utils::get_node_output_layout(brgemm)); + ngraph::copy_runtime_info(brgemm, brgemm_cpu); + ngraph::replace_node(brgemm, brgemm_cpu); + + return true; + }; + + auto m = std::make_shared(m_brgemm, matcher_name); + register_matcher(m, callback); +} +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.hpp b/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.hpp new file mode 100644 index 00000000000..c400d2d0790 --- /dev/null +++ b/src/plugins/intel_cpu/src/snippets_transformations/brgemm_to_brgemm_cpu.hpp @@ -0,0 +1,45 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "ngraph/pass/graph_rewrite.hpp" +#include "ngraph/pattern/matcher.hpp" + +namespace ov { +namespace intel_cpu { +namespace pass { + +/** + * @interface BrgemmToBrgemmCPU + * @brief The pass decompose Snippets Brgemm to specific subgraph that depends on ISA and input precisions: + * - f32|f32: + * BrgemmCPU + * - u8|i8 or bf16|bf16 (non-AMX system): + * \ BrgemmCopyB (the operation for data repacking) + * \ Buffer + * BrgemmCPU + * - i8|i8 (non-AMX system) - needs compensations: + * \ BrgemmCopyB + * \ / \ + * \ Buffer (with repacked data) Buffer (with compensations) + * \ | / + * BrgemmCPU + * - u8|i8, i8|i8 or bf16|bf16 on AMX system: + * \ BrgemmCopyB + * \ Buffer (with repacked data) Buffer (with new memory) + * \ | / + * BrgemmCPU + * @ingroup snippets + */ +class BrgemmToBrgemmCPU: public ngraph::pass::MatcherPass { +public: + OPENVINO_RTTI("BrgemmToBrgemmCPU", "0"); + BrgemmToBrgemmCPU(); +}; + + +} // namespace pass +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.cpp b/src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.cpp index b47fcfe73da..0c64a20b655 100644 --- 
a/src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.cpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/fuse_load_store_and_convert.cpp @@ -10,20 +10,17 @@ #include "snippets_transformations/op/load_convert.hpp" #include "snippets_transformations/op/store_convert.hpp" -#include "ngraph/opsets/opset1.hpp" #include "ngraph/rt_info.hpp" #include "ngraph/pattern/op/wrap_type.hpp" ov::intel_cpu::pass::FuseLoadConvert::FuseLoadConvert() { MATCHER_SCOPE(FuseLoadConvert); - auto param_pattern = ngraph::pattern::wrap_type(); - auto load_pattern = ngraph::pattern::wrap_type({param_pattern}); + auto load_pattern = ngraph::pattern::wrap_type(); auto convert_pattern = ngraph::pattern::wrap_type({load_pattern}); auto callback = [=](ngraph::pattern::Matcher& m) { OV_ITT_SCOPED_TASK(ngraph::pass::itt::domains::SnippetsTransform, "ov::intel_cpu::pass::FuseLoadConvert") auto& pm = m.get_pattern_value_map(); - const auto param = pm.at(param_pattern).get_node_shared_ptr(); const auto load_shared = pm.at(load_pattern).get_node_shared_ptr(); if (!load_shared || load_shared->output(0).get_target_inputs().size() != 1) { return false; @@ -40,12 +37,12 @@ ov::intel_cpu::pass::FuseLoadConvert::FuseLoadConvert() { std::shared_ptr load_convert = nullptr; if (const auto convert_saturation = std::dynamic_pointer_cast(convert)) { - load_convert = std::make_shared(param, + load_convert = std::make_shared(load->input_value(0), convert_saturation->get_destination_type(), load->get_count(), load->get_offset()); } else if (const auto convert_truncation = std::dynamic_pointer_cast(convert)) { - load_convert = std::make_shared(param, + load_convert = std::make_shared(load->input_value(0), convert_truncation->get_destination_type(), load->get_count(), load->get_offset()); } else { @@ -102,7 +99,6 @@ ov::intel_cpu::pass::FuseStoreConvert::FuseStoreConvert() { "Type of Convert op is undefined. 
Supports only fusing Store and ConvertTruncation or ConvertSaturation ops"); } - if (!store_convert) return false; diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp new file mode 100644 index 00000000000..0e4004395e1 --- /dev/null +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.cpp @@ -0,0 +1,78 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/itt.hpp" +#include "snippets/utils.hpp" + +#include "brgemm_copy_b.hpp" + +#include "utils/general_utils.h" + +using namespace std; +using namespace ov; + +intel_cpu::BrgemmCopyB::BrgemmCopyB(const Output& x, const element::Type src_type, const Type type, + const size_t offset_in, const size_t offset_out0, const size_t offset_out1) + : ngraph::snippets::op::MemoryAccess({x}, 1, type == Type::WithCompensations ? 2 : 1), m_type(type), m_src_type(src_type) { + set_output_size(get_output_port_count()); + m_input_ports.resize(get_input_size()); + m_output_ports.resize(get_output_size()); + set_input_port_descriptor({0, offset_in}, 0); + set_output_port_descriptor({0, offset_out0}, 0); + if (is_with_compensations()) { + set_output_port_descriptor({0, offset_out1}, 1); + } + constructor_validate_and_infer_types(); +} + +bool intel_cpu::BrgemmCopyB::visit_attributes(AttributeVisitor& visitor) { + INTERNAL_OP_SCOPE(BrgemmRepack_visit_attributes); + MemoryAccess::visit_attributes(visitor); + visitor.on_attribute("src_type", m_src_type); + return true; +} + +void intel_cpu::BrgemmCopyB::validate_and_infer_types() { + INTERNAL_OP_SCOPE(BrgemmRepack_validate_and_infer_types); + + const auto element_type = get_input_element_type(0); + NGRAPH_CHECK(one_of(element_type, element::bf16, element::i8), + "BrgemmCopyB doesn't support element type" + element_type.get_type_name()); + + const auto pshape = 
ngraph::snippets::utils::get_port_planar_shape(input_value(0)); + if (pshape.is_dynamic()) { + set_output_type(0, element_type, ov::PartialShape{ov::Dimension::dynamic()}); + if (is_with_compensations()) { + set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension::dynamic()}); + } + return; + } + + const auto shape = pshape.get_shape(); + const auto N = *shape.rbegin(); + const auto K = *(shape.rbegin() + 1); + const auto N_blk = element_type == element::bf16 ? 32 : 64; + const auto brgemmVNNIFactor = 4 / m_src_type.size(); + + set_output_type(0, element_type, ov::PartialShape{ov::Dimension(rnd_up(K, brgemmVNNIFactor)), + ov::Dimension(rnd_up(N, N_blk))}); + if (is_with_compensations()) { + set_output_type(1, ov::element::f32, ov::PartialShape{ov::Dimension(rnd_up(N, N_blk))}); + } +} + +std::shared_ptr intel_cpu::BrgemmCopyB::clone_with_new_inputs(const OutputVector& new_args) const { + INTERNAL_OP_SCOPE(BrgemmRepack_clone_with_new_inputs); + check_new_args_count(this, new_args); + return std::make_shared(new_args.at(0), m_src_type, m_type, + get_offset_in(), + get_offset_out(), + is_with_compensations() ? 
get_offset_compensations() : 0); +} + +size_t intel_cpu::BrgemmCopyB::get_offset_compensations() const { + OPENVINO_ASSERT(is_with_compensations() && get_output_size() == 2, + "The offset for compensations must be in BrgemmCopyB only with compensations and 2 outputs!"); + return get_output_offset(1); +} diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.hpp b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.hpp new file mode 100644 index 00000000000..d8db828b4a3 --- /dev/null +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_copy_b.hpp @@ -0,0 +1,51 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "snippets/op/memory_access.hpp" + +namespace ov { +namespace intel_cpu { + +/** +* @interface BrgemmCopyB +* @brief The operation for data repacking of Brgemm with input non-fp32 precisions. + The CPU Generator uses oneDNN primitives for generation code of Brgemm. + OneDNN requiers data repacking for second input of Brgemm with input non-fp32 precisions. 
+* @ingroup snippets +*/ +class BrgemmCopyB : public ngraph::snippets::op::MemoryAccess { +public: + OPENVINO_OP("BrgemmCopyB", "SnippetsOpset", MemoryAccess); + + enum Type { + OnlyRepacking, // Just data repacking - one output + WithCompensations, // Repack data and caclulate compensations - 2 outputs (is needed for BrgemmCPU with compensations) + }; + + BrgemmCopyB(const Output& x, const element::Type src_type, const Type type = Type::OnlyRepacking, + const size_t offset_in = 0lu, const size_t offset_out0 = 0lu, const size_t offset_out1 = 0lu); + BrgemmCopyB() = default; + + size_t get_offset_in() const { return get_input_offset(0); } + size_t get_offset_out() const { return get_output_offset(0); } + size_t get_offset_compensations() const; + + Type get_type() const { return m_type; } + element::Type get_src_element_type() const { return m_src_type; } + bool is_with_compensations() const { return m_type == Type::WithCompensations; } + + bool visit_attributes(AttributeVisitor& visitor) override; + void validate_and_infer_types() override; + bool has_evaluate() const override { return false; } + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; + +private: + Type m_type = Type::OnlyRepacking; + element::Type m_src_type = ov::element::undefined; // src element type of the corresponding BRGEMM +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp new file mode 100644 index 00000000000..67e85394063 --- /dev/null +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.cpp @@ -0,0 +1,117 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "snippets/itt.hpp" +#include "brgemm_cpu.hpp" +#include "ngraph/runtime/host_tensor.hpp" +#include "openvino/core/rt_info.hpp" +#include "snippets/utils.hpp" +#include 
"utils/general_utils.h" + + +namespace ov { +namespace intel_cpu { + +BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Type type, + const size_t offset_a, const size_t offset_b, const size_t offset_c) + : Brgemm(), m_type(type) { + // We call default ctor of Brgemm class to avoid incorrect shape infer in constructor_validate_and_type_infer() call + set_arguments({A, B}); + set_output_size(1); + m_input_ports.resize(get_input_size()); + m_output_ports.resize(get_output_size()); + set_input_port_descriptor({0, offset_a}, 0); + set_input_port_descriptor({0, offset_b}, 1); + set_output_port_descriptor({0, offset_c}, 0); + constructor_validate_and_infer_types(); +} + +BrgemmCPU::BrgemmCPU(const Output& A, const Output& B, const Output& scratch, const Type type, + const size_t offset_a, const size_t offset_b, const size_t offset_scratch, const size_t offset_c) + : Brgemm(), m_type(type) { + set_arguments({A, B, scratch}); + set_output_size(1); + m_input_ports.resize(get_input_size()); + m_output_ports.resize(get_output_size()); + set_input_port_descriptor({0, offset_a}, 0); + set_input_port_descriptor({0, offset_b}, 1); + set_output_port_descriptor({0, offset_c}, 0); + set_input_port_descriptor({0, offset_scratch}, 2); + constructor_validate_and_infer_types(); +} + +void BrgemmCPU::validate_and_infer_types() { + INTERNAL_OP_SCOPE(BrgemmCPU_validate_and_infer_types); + // If no leading dimensions are provided, assume dense row-major inputs-outputs + NODE_VALIDATION_CHECK(this, get_input_partial_shape(0).is_static() && get_input_partial_shape(1).is_static(), + "BrgemmCPU currently supports only static shapes."); + + OPENVINO_ASSERT(implication(one_of(m_type, Type::Floating, Type::WithDataRepacking), get_input_size() == 2), + "BrgemmCPU expects 2 inputs in cases, when input precisions are f32|f32, u8|i8 or bf16|bf16 (non-AMX system)"); + OPENVINO_ASSERT(implication(one_of(m_type, Type::WithCompensations, Type::AMX), get_input_size() == 3), + "BrgemmCPU expects 3 
inputs with input precisions i8|i8 and bf16|bf16 on AMX system"); + + const auto brgemm_copy = is_with_data_repacking() ? get_brgemm_copy() : nullptr; + std::vector planar_input_shapes = { + ngraph::snippets::utils::get_port_planar_shape(input_value(0)), + ngraph::snippets::utils::get_port_planar_shape(brgemm_copy ? brgemm_copy->input_value(0) : input_value(1)) + }; + + auto output_shape = get_output_partial_shape(planar_input_shapes); + const auto& output_layout = ngraph::snippets::utils::get_node_output_layout(this); + set_output_type(0, + get_output_type(), + ngraph::snippets::utils::get_reordered_planar_shape(output_shape, output_layout)); + + //Additional check for 3rd input + if (one_of(m_type, Type::WithCompensations, Type::AMX)) { + const auto shape = get_input_partial_shape(2); + NGRAPH_CHECK(shape.is_static(), "BRGEMM Scratch must have static shape"); + const auto type = get_input_element_type(2); + if (is_with_compensations()) { + const auto element_type_b = get_input_element_type(0); + const auto shape_b = planar_input_shapes[1].get_shape(); + const auto N = *shape_b.rbegin(); + const auto N_blk = element_type_b == element::f32 ? N : + element_type_b == element::bf16 ? 
32 : 64; + const auto expected_shape = ov::Shape{rnd_up(N, N_blk)}; + const auto expected_type = ov::element::f32; + NGRAPH_CHECK(expected_shape == shape.get_shape() && expected_type == type, + "BRGEMM Scratch with compensations must have shape {rnd_up(N, N_blk)} and FP32 element type"); + } else { + NGRAPH_CHECK(ngraph::shape_size(shape.get_shape()) == SCRATCH_BYTE_SIZE && type == ov::element::u8, + "BRGEMM Scratch for space workplace must be static, have U8 element type and size is equal to " + std::to_string(SCRATCH_BYTE_SIZE)); + } + } +} + +std::shared_ptr BrgemmCPU::clone_with_new_inputs(const OutputVector& new_args) const { + INTERNAL_OP_SCOPE(BrgemmCPU_clone_with_new_inputs); + check_new_args_count(this, new_args); + std::shared_ptr new_node = nullptr; + if (!is_with_scratchpad()) { + new_node = std::make_shared(new_args.at(0), new_args.at(1), m_type, + get_offset_a(), get_offset_b(), get_offset_c()); + } else { + new_node = std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), m_type, + get_offset_a(), get_offset_b(), get_offset_scratch(), get_offset_c()); + } + return new_node; +} + +std::shared_ptr BrgemmCPU::get_brgemm_copy() const { + OPENVINO_ASSERT(one_of(m_type, Type::WithDataRepacking, Type::WithCompensations, Type::AMX), "Brgemm doesn't need BrgemmCopyB"); + if (const auto buffer = ov::as_type_ptr(get_input_node_shared_ptr(1))) { + return ov::as_type_ptr(buffer->get_input_node_shared_ptr(0)); + } + throw ov::Exception("BrgemmCopyB hasn't been found!"); +} + +size_t BrgemmCPU::get_offset_scratch() const { + OPENVINO_ASSERT(is_with_scratchpad() && get_input_size() == 3, "Offset of scratchpad must be only in Brgemm with scratchpad on 3rd input"); + return get_input_offset(2); +} + +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.hpp b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.hpp new file mode 100644 index 
00000000000..2081ca25c75 --- /dev/null +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/brgemm_cpu.hpp @@ -0,0 +1,55 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "snippets/op/brgemm.hpp" +#include "brgemm_copy_b.hpp" + +namespace ov { +namespace intel_cpu { + +/** + * @interface BrgemmCPU + * @brief BrgemmCPU is a batch-reduced matrix multiplication with the support of arbitrary strides between matrices rows + * with support of several precisions on plugin level + * @ingroup snippets + */ +class BrgemmCPU : public ngraph::snippets::op::Brgemm { +public: + OPENVINO_OP("BrgemmCPU", "SnippetsOpset", ngraph::snippets::op::Brgemm); + + enum Type { + Floating, // f32|f32 + WithDataRepacking, // u8|i8 or bf16|bf16 (non-AMX system) - needs BrgemmCopyB on second input for data repacking + WithCompensations, // i8|i8 (non-AMX system) - needs BrgemmCopyB for data repacking and compensations + AMX, // i8|i8 or bf16|bf16 on AMX system - needs BrgemmCopyB and scratchpad + }; + + BrgemmCPU(const Output& A, const Output& B, const Type type, + const size_t offset_a = 0, const size_t offset_b = 0, const size_t offset_c = 0); + BrgemmCPU(const Output& A, const Output& B, const Output& scratch, const Type type, + const size_t offset_a = 0, const size_t offset_b = 0, const size_t offset_scratch = 0, const size_t offset_c = 0); + BrgemmCPU() = default; + + void validate_and_infer_types() override; + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; + + Type get_type() const { return m_type; } + bool is_with_compensations() const { return m_type == Type::WithCompensations; } + bool is_with_data_repacking() const { return m_type != Type::Floating; } + bool is_amx() const { return m_type == Type::AMX; } + bool is_with_scratchpad() const { return is_with_compensations() || is_amx(); } + + size_t get_offset_scratch() const; + std::shared_ptr get_brgemm_copy() const; + + 
constexpr static size_t SCRATCH_BYTE_SIZE = 32 * 1024; + +private: + Type m_type = Type::Floating; +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/load_convert.cpp b/src/plugins/intel_cpu/src/snippets_transformations/op/load_convert.cpp index dbb8046f636..a71181e8e2b 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/op/load_convert.cpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/load_convert.cpp @@ -19,6 +19,7 @@ intel_cpu::LoadConvertSaturation::LoadConvertSaturation(const Output& x, c bool intel_cpu::LoadConvertSaturation::visit_attributes(AttributeVisitor& visitor) { INTERNAL_OP_SCOPE(LoadConvert_visit_attributes); + MemoryAccess::visit_attributes(visitor); visitor.on_attribute("destination_type", m_destination_type); return true; } @@ -31,7 +32,8 @@ void intel_cpu::LoadConvertSaturation::validate_and_infer_types() { std::shared_ptr intel_cpu::LoadConvertSaturation::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(LoadConvert_clone_with_new_inputs); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), m_destination_type, m_count, m_offset); + return std::make_shared( + new_args.at(0), m_destination_type, get_count(), get_offset()); } intel_cpu::LoadConvertTruncation::LoadConvertTruncation(const Output& x, const ov::element::Type& destination_type, @@ -42,6 +44,7 @@ intel_cpu::LoadConvertTruncation::LoadConvertTruncation(const Output& x, c bool intel_cpu::LoadConvertTruncation::visit_attributes(AttributeVisitor& visitor) { INTERNAL_OP_SCOPE(LoadConvert_visit_attributes); + MemoryAccess::visit_attributes(visitor); visitor.on_attribute("destination_type", m_destination_type); return true; } @@ -54,5 +57,6 @@ void intel_cpu::LoadConvertTruncation::validate_and_infer_types() { std::shared_ptr intel_cpu::LoadConvertTruncation::clone_with_new_inputs(const OutputVector& new_args) const { 
INTERNAL_OP_SCOPE(LoadConvert_clone_with_new_inputs); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), m_destination_type, m_count, m_offset); + return std::make_shared( + new_args.at(0), m_destination_type, get_count(), get_offset()); } diff --git a/src/plugins/intel_cpu/src/snippets_transformations/op/store_convert.cpp b/src/plugins/intel_cpu/src/snippets_transformations/op/store_convert.cpp index 52921e681e9..d7e1c9e4b05 100644 --- a/src/plugins/intel_cpu/src/snippets_transformations/op/store_convert.cpp +++ b/src/plugins/intel_cpu/src/snippets_transformations/op/store_convert.cpp @@ -19,6 +19,7 @@ intel_cpu::StoreConvertSaturation::StoreConvertSaturation(const Output& x, bool intel_cpu::StoreConvertSaturation::visit_attributes(AttributeVisitor& visitor) { INTERNAL_OP_SCOPE(StoreConvert_visit_attributes); + MemoryAccess::visit_attributes(visitor); visitor.on_attribute("destination_type", m_destination_type); return true; } @@ -31,7 +32,8 @@ void intel_cpu::StoreConvertSaturation::validate_and_infer_types() { std::shared_ptr intel_cpu::StoreConvertSaturation::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(StoreConvert_clone_with_new_inputs); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), m_destination_type, m_count, m_offset); + return std::make_shared( + new_args.at(0), m_destination_type, get_count(), get_offset()); } intel_cpu::StoreConvertTruncation::StoreConvertTruncation(const Output& x, const ov::element::Type& destination_type, @@ -42,6 +44,7 @@ intel_cpu::StoreConvertTruncation::StoreConvertTruncation(const Output& x, bool intel_cpu::StoreConvertTruncation::visit_attributes(AttributeVisitor& visitor) { INTERNAL_OP_SCOPE(StoreConvert_visit_attributes); + MemoryAccess::visit_attributes(visitor); visitor.on_attribute("destination_type", m_destination_type); return true; } @@ -54,5 +57,6 @@ void intel_cpu::StoreConvertTruncation::validate_and_infer_types() { 
std::shared_ptr intel_cpu::StoreConvertTruncation::clone_with_new_inputs(const OutputVector& new_args) const { INTERNAL_OP_SCOPE(StoreConvert_clone_with_new_inputs); check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), m_destination_type, m_count, m_offset); + return std::make_shared( + new_args.at(0), m_destination_type, get_count(), get_offset()); } diff --git a/src/plugins/intel_cpu/src/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformation_pipeline.cpp index 174a10cd6d0..2bedc4d32df 100644 --- a/src/plugins/intel_cpu/src/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformation_pipeline.cpp @@ -556,7 +556,7 @@ void Transformations::PostLpt() { void Transformations::MainSnippets(void) { if (snippetsMode == Config::SnippetsMode::Disable || - !dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) // snippets are implemeted only for relevant platforms (avx2+ extentions) + !dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) // snippets are implemented only for relevant platforms (avx2+ extensions) return; ngraph::pass::Manager snippetsManager; diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp index acbd2310908..5e246017855 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp @@ -207,6 +207,10 @@ std::vector disabledTestPatterns() { retVector.emplace_back(R"(.*Snippets.*MHA.*)"); retVector.emplace_back(R"(.*Snippets.*(MatMul|Matmul).*)"); } + if (!InferenceEngine::with_cpu_x86_avx512_core_vnni() && !InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) { + // MatMul in Snippets uses BRGEMM that supports i8 only on platforms with VNNI or AMX instructions + retVector.emplace_back(R"(.*Snippets.*MatMulFQ.*)"); + } if 
(!InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) //TODO: Issue 92895 // on platforms which do not support AMX, we are disabling I8 input tests diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp index 9ab22c79d2e..9d792f35264 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/matmul.cpp @@ -4,6 +4,7 @@ #include "snippets/matmul.hpp" #include "common_test_utils/test_constants.hpp" +#include "ie_system_conf.h" namespace ov { namespace test { @@ -16,49 +17,47 @@ std::vector> input_shapes{ {{3, 1, 32, 14}, {1, 2, 14, 32}}, {{1, 2, 37, 23}, {2, 1, 23, 37}}, {{1, 1, 37, 23}, {1, 2, 23, 33}}, - {{2, 1, 69, 43}, {1, 1, 43, 49}} + {{1, 16, 384, 64}, {1, 16, 64, 384}} }; -std::vector precisions{element::f32}; +static inline std::vector> precisions(bool only_fp32 = true) { + std::vector> prc = { + {element::f32, element::f32}, + }; + if (!only_fp32) { + // In Snippets MatMul INT8 is supported only on VNNI/AMX platforms + if (InferenceEngine::with_cpu_x86_avx512_core_vnni() || InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) { + prc.emplace_back(std::vector{element::i8, element::i8}); + prc.emplace_back(std::vector{element::u8, element::i8}); + } + // In Snippets MatMul BF16 is supported only on bf16/AMX platforms + if (InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16()) { + prc.emplace_back(std::vector{element::bf16, element::bf16}); + } + } + return prc; +} INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, MatMul, ::testing::Combine( ::testing::ValuesIn(input_shapes), - ::testing::ValuesIn(precisions), - ::testing::Values(1), // MatMu; + ::testing::ValuesIn(precisions(false)), + ::testing::Values(1), // MatMul ::testing::Values(1), // Tokenized MatMul 
::testing::Values(CommonTestUtils::DEVICE_CPU)), MatMul::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulFQ, MatMulFQ, + ::testing::Combine( + ::testing::ValuesIn(input_shapes), + ::testing::ValuesIn(precisions()), + ::testing::Values(1), // MatMul + ::testing::Values(1), // Tokenized MatMul + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + MatMul::getTestCaseName); + INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMulBias, MatMulBias, ::testing::Combine( ::testing::Values(std::vector{{1, 2, 69, 43}, {2, 1, 43, 49}, {1, 1, 69, 49}}), - ::testing::ValuesIn(precisions), - ::testing::Values(1), // Subgraph; - ::testing::Values(1), // Tokenized MatMul+Bias - ::testing::Values(CommonTestUtils::DEVICE_CPU)), - MatMul::getTestCaseName); - -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_ExplicitTransposeMatMul, ExplicitTransposeMatMul, - ::testing::Combine( - ::testing::Values(std::vector{{1, 2, 69, 43}, {2, 49, 2, 43}}), - ::testing::ValuesIn(precisions), - ::testing::Values(1), // Subgraph; - ::testing::Values(1), // Tokenized MatMul+Bias - ::testing::Values(CommonTestUtils::DEVICE_CPU)), - ExplicitTransposeMatMul::getTestCaseName); - -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulBias, ExplicitTransposeMatMulBias, - ::testing::Combine( - ::testing::Values(std::vector{{1, 2, 69, 43}, {2, 49, 2, 43}, {1, 1, 69, 49}}), - ::testing::ValuesIn(precisions), - ::testing::Values(1), // Subgraph; - ::testing::Values(1), // Tokenized MatMul+Bias - ::testing::Values(CommonTestUtils::DEVICE_CPU)), - MatMul::getTestCaseName); - -INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMulMatMulBias, ExplicitTransposeMulMatMulBias, - ::testing::Combine( - ::testing::Values(std::vector{{1, 2, 69, 43}, {2, 49, 2, 43}, {1, 2, 1, 1}, {1, 1, 69, 49}}), - ::testing::ValuesIn(precisions), + ::testing::ValuesIn(precisions(false)), ::testing::Values(1), // Subgraph; ::testing::Values(1), // Tokenized MatMul+Bias ::testing::Values(CommonTestUtils::DEVICE_CPU)), diff --git 
a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_matmul.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_matmul.cpp index 8e3af45fd52..6423f5a3db4 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_matmul.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/snippets/transpose_matmul.cpp @@ -4,6 +4,7 @@ #include "snippets/transpose_matmul.hpp" #include "common_test_utils/test_constants.hpp" +#include "ie_system_conf.h" namespace ov { namespace test { @@ -11,7 +12,23 @@ namespace snippets { namespace { -std::vector precisions{element::f32}; +static inline std::vector> precisions(bool only_fp32 = true) { + std::vector> prc = { + {element::f32, element::f32}, + }; + if (!only_fp32) { + // In Snippets MatMul INT8 is supported only on VNNI/AMX platforms + if (InferenceEngine::with_cpu_x86_avx512_core_vnni() || InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) { + prc.emplace_back(std::vector{element::i8, element::i8}); + prc.emplace_back(std::vector{element::u8, element::i8}); + } + // In Snippets MatMul BF16 is supported only on bf16/AMX platforms + if (InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16()) { + prc.emplace_back(std::vector{element::bf16, element::bf16}); + } + } + return prc; +} namespace transpose_zero_input { std::vector> transpose_input_shapes{ {{1, 49, 2, 23}, {2, 2, 23, 39}} @@ -20,11 +37,23 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, TransposeMatMul, ::testing::Combine( ::testing::ValuesIn(transpose_input_shapes), ::testing::Values(0), // Transpose on 0th Matmul input - ::testing::ValuesIn(precisions), - ::testing::Values(1), // MatMul; + ::testing::ValuesIn(precisions(false)), + ::testing::Values(1), // MatMul ::testing::Values(1), // Tokenized MatMul + FusedTranspose ::testing::Values(CommonTestUtils::DEVICE_CPU)), TransposeMatMul::getTestCaseName); + +// 
TODO: FuseTransposeToBrgemm supports fusing only if Transpose is before Parameter in cases when Transpose is on input at the moment +// When we support the branch Parameter->FQ->Transpose->MatMul[0th input], uncomment this test case please +// INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulFQ, TransposeMatMulFQ, +// ::testing::Combine( +// ::testing::ValuesIn(transpose_input_shapes), +// ::testing::Values(0), // Transpose on 0th Matmul input +// ::testing::Values(ov::element::i8), +// ::testing::Values(1), // MatMul +// ::testing::Values(1), // Tokenized MatMul + FusedTranspose +// ::testing::Values(CommonTestUtils::DEVICE_CPU)), +// TransposeMatMulFQ::getTestCaseName); } // namespace transpose_zero_input namespace transpose_first_input { @@ -35,11 +64,21 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, TransposeMatMul, ::testing::Combine( ::testing::ValuesIn(transpose_input_shapes), ::testing::Values(1), // Transpose on 1st Matmul input - ::testing::ValuesIn(precisions), - ::testing::Values(1), // MatMu; + ::testing::ValuesIn(precisions(false)), + ::testing::Values(1), // MatMul ::testing::Values(1), // Tokenized MatMul + FusedTranspose ::testing::Values(CommonTestUtils::DEVICE_CPU)), TransposeMatMul::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulFQ, TransposeMatMulFQ, + ::testing::Combine( + ::testing::ValuesIn(transpose_input_shapes), + ::testing::Values(1), // Transpose on 1st Matmul input + ::testing::ValuesIn(precisions()), + ::testing::Values(1), // MatMul + ::testing::Values(1), // Tokenized MatMul + FusedTranspose + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + TransposeMatMulFQ::getTestCaseName); } // namespace transpose_first_input namespace transpose_output { @@ -50,13 +89,64 @@ INSTANTIATE_TEST_SUITE_P(smoke_Snippets_MatMult, TransposeMatMul, ::testing::Combine( ::testing::ValuesIn(transpose_input_shapes), ::testing::Values(2), // Transpose on Matmul output - ::testing::ValuesIn(precisions), - 
::testing::Values(1), // MatMu; + ::testing::ValuesIn(precisions()), + ::testing::Values(1), // MatMul ::testing::Values(1), // Tokenized MatMul + FusedTranspose ::testing::Values(CommonTestUtils::DEVICE_CPU)), TransposeMatMul::getTestCaseName); + +// TODO: At the moment we don't support the branch MatMul[output]->Transpose->FQ. +// When we add support, uncomment this test case please +// INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulFQ, TransposeMatMulFQ, +// ::testing::Combine( +// ::testing::ValuesIn(transpose_input_shapes), +// ::testing::Values(2), // Transpose on Matmul output +// ::testing::Values(ov::element::i8), +// ::testing::Values(1), // MatMul +// ::testing::Values(1), // Tokenized MatMul + FusedTranspose +// ::testing::Values(CommonTestUtils::DEVICE_CPU)), +// TransposeMatMulFQ::getTestCaseName); } // namespace transpose_output +namespace explicit_transpose { +static inline std::vector> precisions(bool only_fp32 = true) { + std::vector> prc = { + {element::f32, element::f32}, + }; + if (!only_fp32) { + // In Snippets MatMul INT8 is supported only on VNNI/AMX platforms + if (InferenceEngine::with_cpu_x86_avx512_core_vnni() || InferenceEngine::with_cpu_x86_avx512_core_amx_int8()) { + prc.emplace_back(std::vector{element::i8, element::i8}); + prc.emplace_back(std::vector{element::u8, element::i8}); + } + // In Snippets MatMul BF16 is supported only on bf16/AMX platforms + if (InferenceEngine::with_cpu_x86_bfloat16() || InferenceEngine::with_cpu_x86_avx512_core_amx_bf16()) { + prc.emplace_back(std::vector{element::bf16, element::bf16}); + } + } + return prc; +} +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_ExplicitTransposeMatMul, ExplicitTransposeMatMul, + ::testing::Combine( + ::testing::Values(std::vector{{1, 2, 69, 43}, {2, 49, 2, 43}}), + ::testing::Values(1), // Transpose on second input + ::testing::ValuesIn(precisions()), + ::testing::Values(1), // Subgraph; + ::testing::Values(1), // Tokenized MatMul+Bias + 
::testing::Values(CommonTestUtils::DEVICE_CPU)), + ExplicitTransposeMatMul::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_TransposeMatMulBias, ExplicitTransposeMatMulBias, + ::testing::Combine( + ::testing::Values(std::vector{{1, 2, 69, 43}, {2, 49, 2, 43}, {1, 1, 69, 49}}), + ::testing::Values(1), // Transpose on second input + ::testing::ValuesIn(precisions()), + ::testing::Values(1), // Subgraph; + ::testing::Values(1), // Tokenized MatMul+Bias + ::testing::Values(CommonTestUtils::DEVICE_CPU)), + ExplicitTransposeMatMulBias::getTestCaseName); +} // namespace explicit_transpose + } // namespace } // namespace snippets } // namespace test diff --git a/src/tests/functional/plugin/shared/include/snippets/matmul.hpp b/src/tests/functional/plugin/shared/include/snippets/matmul.hpp index 770dffa3c5e..3e2a0ab015e 100644 --- a/src/tests/functional/plugin/shared/include/snippets/matmul.hpp +++ b/src/tests/functional/plugin/shared/include/snippets/matmul.hpp @@ -12,21 +12,12 @@ namespace snippets { typedef std::tuple< std::vector, // Input Shapes - ov::element::Type, // Element type + std::vector,// Input Element types size_t, // Expected num nodes size_t, // Expected num subgraphs std::string // Target Device > MatMulParams; -typedef std::tuple< - std::vector, // Input Shapes - size_t , // Transpose position - ov::element::Type, // Element type - size_t, // Expected num nodes - size_t, // Expected num subgraphs - std::string // Target Device -> TransposeMatMulParams; - class MatMul : public testing::WithParamInterface, virtual public ov::test::SnippetsTestsCommon { public: @@ -36,26 +27,16 @@ protected: void SetUp() override; }; +class MatMulFQ : public MatMul { +protected: + void SetUp() override; +}; + class MatMulBias : public MatMul { protected: void SetUp() override; }; -class ExplicitTransposeMatMul : public MatMul { -protected: - void SetUp() override; -}; - -class ExplicitTransposeMatMulBias : public MatMul { -protected: - void SetUp() override; -}; 
- -class ExplicitTransposeMulMatMulBias : public MatMul { -protected: - void SetUp() override; -}; - } // namespace snippets } // namespace test } // namespace ov \ No newline at end of file diff --git a/src/tests/functional/plugin/shared/include/snippets/transpose_matmul.hpp b/src/tests/functional/plugin/shared/include/snippets/transpose_matmul.hpp index f949e9df9d5..6eadc733042 100644 --- a/src/tests/functional/plugin/shared/include/snippets/transpose_matmul.hpp +++ b/src/tests/functional/plugin/shared/include/snippets/transpose_matmul.hpp @@ -13,7 +13,7 @@ namespace snippets { typedef std::tuple< std::vector, // Input Shapes size_t , // Transpose position - ov::element::Type, // Element type + std::vector,// Input Element types size_t, // Expected num nodes size_t, // Expected num subgraphs std::string // Target Device @@ -28,6 +28,21 @@ protected: void SetUp() override; }; +class TransposeMatMulFQ : public TransposeMatMul { +protected: + void SetUp() override; +}; + +class ExplicitTransposeMatMul : public TransposeMatMul { +protected: + void SetUp() override; +}; + +class ExplicitTransposeMatMulBias : public TransposeMatMul { +protected: + void SetUp() override; +}; + } // namespace snippets } // namespace test } // namespace ov \ No newline at end of file diff --git a/src/tests/functional/plugin/shared/src/snippets/matmul.cpp b/src/tests/functional/plugin/shared/src/snippets/matmul.cpp index 1c38a168e83..06a37e2fd1f 100644 --- a/src/tests/functional/plugin/shared/src/snippets/matmul.cpp +++ b/src/tests/functional/plugin/shared/src/snippets/matmul.cpp @@ -14,14 +14,15 @@ namespace snippets { std::string MatMul::getTestCaseName(testing::TestParamInfo obj) { std::vector input_shapes; - ov::element::Type elem_type; + std::vector elem_types; std::string targetDevice; size_t num_nodes, num_subgraphs; - std::tie(input_shapes, elem_type, num_nodes, num_subgraphs, targetDevice) = obj.param; + std::tie(input_shapes, elem_types, num_nodes, num_subgraphs, targetDevice) = 
obj.param; std::ostringstream result; for (size_t i = 0; i < input_shapes.size(); i++) result << "IS[" << i <<"]=" << CommonTestUtils::partialShape2str({input_shapes[i]}) << "_"; - result << "T=" << elem_type << "_"; + for (size_t i = 0; i < elem_types.size(); i++) + result << "T[" << i <<"]=" << elem_types[i] << "_"; result << "#N=" << num_nodes << "_"; result << "#S=" << num_subgraphs << "_"; result << "targetDevice=" << targetDevice; @@ -30,11 +31,25 @@ std::string MatMul::getTestCaseName(testing::TestParamInfo input_shapes; - ov::element::Type elem_type; - std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + std::vector elem_types; + std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); - auto f = ov::test::snippets::MatMulFunction(input_shapes); + auto f = ov::test::snippets::MatMulFunction(input_shapes, elem_types); + function = f.getOriginal(); + if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { + configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, + InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); + } +} + +void MatMulFQ::SetUp() { + std::vector input_shapes; + std::vector elem_types; + std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); + + auto f = ov::test::snippets::FQMatMulFunction(input_shapes); function = f.getOriginal(); if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, @@ -44,53 +59,11 @@ void MatMul::SetUp() { void MatMulBias::SetUp() { std::vector input_shapes; - ov::element::Type elem_type; - std::tie(input_shapes, elem_type, 
ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + std::vector elem_types; + std::tie(input_shapes, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); - auto f = ov::test::snippets::MatMulBiasFunction(input_shapes); - function = f.getOriginal(); - if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { - configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, - InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); - } -} - -void ExplicitTransposeMatMul::SetUp() { - std::vector input_shapes; - ov::element::Type elem_type; - std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); - init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); - - auto f = ov::test::snippets::TransposeMatMulFunction(input_shapes); - function = f.getOriginal(); - if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { - configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, - InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); - } -} - -void ExplicitTransposeMatMulBias::SetUp() { - std::vector input_shapes; - ov::element::Type elem_type; - std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); - init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); - - auto f = ov::test::snippets::TransposeMatMulBiasFunction(input_shapes); - function = f.getOriginal(); - if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { - configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, - InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); - } -} - -void ExplicitTransposeMulMatMulBias::SetUp() { - std::vector 
input_shapes; - ov::element::Type elem_type; - std::tie(input_shapes, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); - init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); - - auto f = ov::test::snippets::TransposeMulMatMulBiasFunction(input_shapes); + auto f = ov::test::snippets::MatMulBiasFunction(input_shapes, elem_types); function = f.getOriginal(); if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, @@ -99,26 +72,19 @@ void ExplicitTransposeMulMatMulBias::SetUp() { } TEST_P(MatMul, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() + run(); + validateNumSubgraphs(); +} + +TEST_P(MatMulFQ, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() run(); validateNumSubgraphs(); } TEST_P(MatMulBias, CompareWithRefImpl) { - run(); - validateNumSubgraphs(); -} - -TEST_P(ExplicitTransposeMatMul, CompareWithRefImpl) { - run(); - validateNumSubgraphs(); -} - -TEST_P(ExplicitTransposeMatMulBias, CompareWithRefImpl) { - run(); - validateNumSubgraphs(); -} - -TEST_P(ExplicitTransposeMulMatMulBias, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() run(); validateNumSubgraphs(); } diff --git a/src/tests/functional/plugin/shared/src/snippets/transpose_matmul.cpp b/src/tests/functional/plugin/shared/src/snippets/transpose_matmul.cpp index 68a2140339f..f3fc23c2ce4 100644 --- a/src/tests/functional/plugin/shared/src/snippets/transpose_matmul.cpp +++ b/src/tests/functional/plugin/shared/src/snippets/transpose_matmul.cpp @@ -15,17 +15,17 @@ namespace snippets { std::string TransposeMatMul::getTestCaseName(testing::TestParamInfo obj) { std::vector input_shapes; size_t transpose_position; - ov::element::Type elem_type; + std::vector elem_types; std::string targetDevice; size_t num_nodes, num_subgraphs; - std::tie(input_shapes, transpose_position, elem_type, num_nodes, num_subgraphs, 
targetDevice) = obj.param; - if (input_shapes.size() != 2) - IE_THROW() << "Invalid input shapes vector size"; + std::tie(input_shapes, transpose_position, elem_types, num_nodes, num_subgraphs, targetDevice) = obj.param; std::ostringstream result; - result << "IS[0]=" << CommonTestUtils::partialShape2str({input_shapes[0]}) << "_"; - result << "IS[1]=" << CommonTestUtils::partialShape2str({input_shapes[1]}) << "_"; + for (size_t i = 0; i < input_shapes.size(); ++i) { + result << "IS[" << i << "]=" << CommonTestUtils::partialShape2str({input_shapes[i]}) << "_"; + } result << "Pos=" << transpose_position << "_"; - result << "T=" << elem_type << "_"; + for (size_t i = 0; i < elem_types.size(); i++) + result << "T[" << i <<"]=" << elem_types[i] << "_"; result << "#N=" << num_nodes << "_"; result << "#S=" << num_subgraphs << "_"; result << "targetDevice=" << targetDevice; @@ -35,11 +35,56 @@ std::string TransposeMatMul::getTestCaseName(testing::TestParamInfo input_shapes; size_t transpose_position; - ov::element::Type elem_type; - std::tie(input_shapes, transpose_position, elem_type, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + std::vector elem_types; + std::tie(input_shapes, transpose_position, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); - auto f = ov::test::snippets::Transpose0213MatMulFunction(input_shapes, transpose_position); + auto f = ov::test::snippets::Transpose0213MatMulFunction(input_shapes, elem_types, transpose_position); + function = f.getOriginal(); + if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { + configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, + InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); + } +} + +void TransposeMatMulFQ::SetUp() { + std::vector input_shapes; + size_t transpose_position; + std::vector elem_types; + 
std::tie(input_shapes, transpose_position, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); + + auto f = ov::test::snippets::FQMatMulFunction(input_shapes, transpose_position); + function = f.getOriginal(); + if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { + configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, + InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); + } +} + +void ExplicitTransposeMatMul::SetUp() { + std::vector input_shapes; + size_t transpose_position; + std::vector elem_types; + std::tie(input_shapes, transpose_position, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); + + auto f = ov::test::snippets::TransposeMatMulFunction(input_shapes); + function = f.getOriginal(); + if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { + configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, + InferenceEngine::PluginConfigInternalParams::IGNORE_CALLBACK}); + } +} + +void ExplicitTransposeMatMulBias::SetUp() { + std::vector input_shapes; + size_t transpose_position; + std::vector elem_types; + std::tie(input_shapes, transpose_position, elem_types, ref_num_nodes, ref_num_subgraphs, targetDevice) = this->GetParam(); + init_input_shapes(static_partial_shapes_to_test_representation(input_shapes)); + + auto f = ov::test::snippets::TransposeMatMulBiasFunction(input_shapes); function = f.getOriginal(); if (!configuration.count(InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE)) { configuration.insert({InferenceEngine::PluginConfigInternalParams::KEY_SNIPPETS_MODE, @@ -48,6 +93,25 @@ void TransposeMatMul::SetUp() { } TEST_P(TransposeMatMul, CompareWithRefImpl) { + 
SKIP_IF_CURRENT_TEST_IS_DISABLED() + run(); + validateNumSubgraphs(); +} + +TEST_P(TransposeMatMulFQ, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() + run(); + validateNumSubgraphs(); +} + +TEST_P(ExplicitTransposeMatMul, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() + run(); + validateNumSubgraphs(); +} + +TEST_P(ExplicitTransposeMatMulBias, CompareWithRefImpl) { + SKIP_IF_CURRENT_TEST_IS_DISABLED() run(); validateNumSubgraphs(); } diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_lowered.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_lowered.hpp index c583b5882ab..40f8c20c9f3 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_lowered.hpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_lowered.hpp @@ -56,7 +56,7 @@ private: class Transpose0213MatMulLoweredFunction : public Transpose0213MatMulFunction { public: explicit Transpose0213MatMulLoweredFunction(const std::vector& inputShapes, size_t position = 0) : - Transpose0213MatMulFunction(inputShapes, position) { + Transpose0213MatMulFunction(inputShapes, std::vector{ov::element::f32, ov::element::f32}, position) { } protected: std::shared_ptr initLowered() const override; diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp index ea533334e80..15954605e69 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/include/subgraph_matmul.hpp @@ -6,6 +6,7 @@ #include "ngraph/ngraph.hpp" #include "./snippets_helpers.hpp" +#include "snippets/utils.hpp" /* This file contains definitions of relatively simple functions (models) that will be used * to test snippets-specific behavior. 
All the functions are expected to be direct descendants of @@ -20,48 +21,77 @@ namespace snippets { // in1 in2 // Matmul // Result -// todo: remove once "no subgraph after input" limitation is relaxed class MatMulFunction : public SnippetsFunctionBase { public: - explicit MatMulFunction(const std::vector& inputShapes) - : SnippetsFunctionBase(inputShapes) { + explicit MatMulFunction(const std::vector& inputShapes, const std::vector& precisions) + : SnippetsFunctionBase(inputShapes), precisions(precisions) { NGRAPH_CHECK(input_shapes.size() == 2, "Got invalid number of input shapes"); + verify_precisions(precisions); + } + static void verify_precisions(const std::vector& precisions) { + NGRAPH_CHECK(precisions.size() == 2, "Got invalid number of input element types"); + const bool is_f32 = ngraph::snippets::utils::everyone_is(element::f32, precisions[0], precisions[1]); + const bool is_int8 = ngraph::snippets::utils::one_of(precisions[0], element::i8, element::u8) && precisions[1] == element::i8; + const bool is_bf16 = ngraph::snippets::utils::everyone_is(element::bf16, precisions[0], precisions[1]); + NGRAPH_CHECK(is_f32 || is_bf16 || is_int8, "Invalid precisions"); } protected: std::shared_ptr initOriginal() const override; std::shared_ptr initReference() const override; + + std::vector precisions; +}; + +class FQMatMulFunction : public SnippetsFunctionBase { +public: + explicit FQMatMulFunction(const std::vector& inputShapes, int pos = -1) : SnippetsFunctionBase({inputShapes[0]}), pos(pos) { + NGRAPH_CHECK(inputShapes.size() == 2, "Got invalid number of input shapes"); + NGRAPH_CHECK(pos >=-1 && pos <= 2, "Got invalid transpose position"); + const_shape = inputShapes[1]; + } +protected: + std::shared_ptr initOriginal() const override; + + ov::PartialShape const_shape; + int pos = -1; }; // As same as MatMulFunction but with biases class MatMulBiasFunction : public SnippetsFunctionBase { public: - explicit MatMulBiasFunction(const std::vector& inputShapes) - : 
SnippetsFunctionBase(inputShapes) { + explicit MatMulBiasFunction(const std::vector& inputShapes, const std::vector& precisions) + : SnippetsFunctionBase(inputShapes), precisions(precisions) { NGRAPH_CHECK(input_shapes.size() == 3, "Got invalid number of input shapes"); + MatMulFunction::verify_precisions(precisions); } protected: std::shared_ptr initOriginal() const override; + + std::vector precisions; }; /// Minimal graph to test MatMul+Transpose combinations. Transpose location is specified via the position argument: /// 0 - before the first MatMul input; 1 - before the second MatMul input; 2 - after the MatMul output. /// Tokenized simply by starting subgraph, // in1 in2 -// Transpose / +// Transpose / // Matmul // Result class Transpose0213MatMulFunction : public SnippetsFunctionBase { public: - explicit Transpose0213MatMulFunction(const std::vector& inputShapes, size_t position = 0) - : SnippetsFunctionBase(inputShapes), transpose_position(position) { + explicit Transpose0213MatMulFunction(const std::vector& inputShapes, const std::vector& precisions, + size_t position = 0) + : SnippetsFunctionBase(inputShapes), transpose_position(position), precisions(precisions) { NGRAPH_CHECK(input_shapes.size() == 2, "Got invalid number of input shapes"); NGRAPH_CHECK(input_shapes[0].rank().get_length() == 4 && input_shapes[1].rank().get_length() == 4, "Only rank 4 input shapes are supported by this test"); NGRAPH_CHECK(transpose_position >=0 && transpose_position <= 2, "Got invalid transpose position"); + MatMulFunction::verify_precisions(precisions); } protected: std::shared_ptr initOriginal() const override; size_t transpose_position; + std::vector precisions; }; class TransposeMatMulFunction : public SnippetsFunctionBase { diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp index 22b86982e9e..6c818b6078c 100644 --- 
a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_lowered.cpp @@ -107,8 +107,8 @@ std::shared_ptr EltwiseThreeInputsLoweredFunction::initLowered() cons } std::shared_ptr Transpose0213MatMulLoweredFunction::initLowered() const { - ParameterVector data{std::make_shared(precision, input_shapes[0]), - std::make_shared(precision, input_shapes[1])}; + ParameterVector data{std::make_shared(precisions[0], input_shapes[0]), + std::make_shared(precisions[1], input_shapes[1])}; std::vector layout{0, 2, 1, 3}; // Note: validity of transpose_position values is checked in Transpose0213MatMulSinhFunction constructor if (transpose_position <= 1) { @@ -194,6 +194,7 @@ std::shared_ptr SoftmaxLoweredFunction::initLowered() const { const auto horizon_sum = std::make_shared(sum); horizon_sum->add_control_dependency(loop_sum_end); + const auto size_exp = std::make_shared(ov::element::i32, ov::Shape{2}); const auto buffer_exp = std::make_shared(loop_sum_end->output(0)); loop_sum_begin->add_control_dependency(vector_buffer_sum); @@ -303,6 +304,7 @@ std::shared_ptr AddSoftmaxLoweredFunction::initLowered() const { /* =========================================== */ + const auto size_add = std::make_shared(ov::element::i32, ov::Shape{2}); const auto buffer_add = std::make_shared(loop_max_end->output(0)); /* === Sub + Exp + ReduceSum decomposition === */ @@ -331,6 +333,7 @@ std::shared_ptr AddSoftmaxLoweredFunction::initLowered() const { const auto horizon_sum = std::make_shared(sum); horizon_sum->add_control_dependency(loop_sum_end); + const auto size_exp = std::make_shared(ov::element::i32, ov::Shape{2}); const auto buffer_exp = std::make_shared(loop_sum_end->output(0)); loop_sum_begin->add_control_dependency(vector_buffer_sum); diff --git a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp index 
af312a2ee2d..b213c66ecca 100644 --- a/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp +++ b/src/tests/ngraph_helpers/snippets_ngraph_functions/src/subgraph_matmul.cpp @@ -5,50 +5,133 @@ #include "subgraph_matmul.hpp" #include "common_test_utils/data_utils.hpp" #include +#include "ngraph_functions/builders.hpp" +#include "ov_ops/type_relaxed.hpp" + namespace ov { namespace test { namespace snippets { std::shared_ptr MatMulFunction::initOriginal() const { - auto data0 = std::make_shared(precision, input_shapes[0]); - auto data1 = std::make_shared(precision, input_shapes[1]); - auto matmul = std::make_shared(data0, data1); + auto data0 = std::make_shared(precisions[0], input_shapes[0]); + auto data1 = std::make_shared(precisions[1], input_shapes[1]); + std::shared_ptr matmul; + if (precisions[1] == ov::element::i8) { + matmul = std::make_shared>( + std::vector{element::f32, element::f32}, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(data0, element::f32).get(), + ov::op::TemporaryReplaceOutputType(data1, element::f32).get()); + } else { + matmul = std::make_shared(data0, data1); + } return std::make_shared(NodeVector{matmul}, ParameterVector{data0, data1}); } std::shared_ptr MatMulFunction::initReference() const { + auto data0 = std::make_shared(precisions[0], input_shapes[0]); + auto data1 = std::make_shared(precisions[1], input_shapes[1]); + auto indata0 = std::make_shared(precisions[0], data0->get_output_partial_shape(0)); + auto indata1 = std::make_shared(precisions[1], data1->get_output_partial_shape(0)); + std::shared_ptr matmul; + if (precisions[1] == ov::element::i8) { + matmul = std::make_shared>( + std::vector{ element::f32, element::f32 }, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(indata0, element::f32).get(), + ov::op::TemporaryReplaceOutputType(indata1, element::f32).get()); + } else { + matmul = std::make_shared(indata0, indata1); + } + const auto subgraph = 
std::make_shared(NodeVector{data0, data1}, + std::make_shared(NodeVector{matmul}, + ParameterVector{indata0, indata1})); + return std::make_shared(NodeVector{subgraph}, ParameterVector{data0, data1}); +} +std::shared_ptr FQMatMulFunction::initOriginal() const { + auto const_order = std::make_shared(ov::element::i32, Shape {4}, std::vector{0, 2, 1, 3}); auto data0 = std::make_shared(precision, input_shapes[0]); - auto data1 = std::make_shared(precision, input_shapes[1]); - auto indata0 = std::make_shared(precision, data0->get_output_partial_shape(0)); - auto indata1 = std::make_shared(precision, data1->get_output_partial_shape(0)); - auto matmul = std::make_shared(NodeVector{data0, data1}, - std::make_shared(NodeVector{std::make_shared(indata0, indata1)}, - ParameterVector{indata0, indata1})); - return std::make_shared(NodeVector{matmul}, ParameterVector{data0, data1}); + auto ih = std::make_shared(ov::element::f32, ov::Shape{1}, std::vector{34.7436294}); + auto il = std::make_shared(ov::element::f32, ov::Shape{1}, std::vector{-35.0172004}); + auto oh = std::make_shared(ov::element::f32, ov::Shape{1}, std::vector{34.7436294}); + auto ol = std::make_shared(ov::element::f32, ov::Shape{1}, std::vector{-35.0172004}); + auto fq = std::make_shared(data0, il, ih, ol, oh, 256); + std::shared_ptr in0 = fq; + if (pos == 0) { + in0 = std::make_shared(in0, const_order); + } + auto constant = ngraph::builder::makeConstant(ov::element::i8, const_shape.get_shape(), std::vector{}, true); + auto convert = std::make_shared(constant, ov::element::f32); + auto deq_mul = std::make_shared(ov::element::f32, ov::Shape{1}, std::vector{0.00499185826}); + auto mul = std::make_shared(convert, deq_mul); + std::shared_ptr in1 = mul; + if (pos == 1) { + in1 = std::make_shared(in1, const_order); + } + auto matmul = std::make_shared(in0, in1); + std::shared_ptr out = matmul; + if (pos == 2) { + out = std::make_shared(out, const_order); + } + return std::make_shared(NodeVector{out}, 
ParameterVector{data0}); } std::shared_ptr MatMulBiasFunction::initOriginal() const { auto data0 = std::make_shared(precision, input_shapes[0]); auto data1 = std::make_shared(precision, input_shapes[1]); - auto matmul = std::make_shared(data0, data1); auto data2 = std::make_shared(precision, input_shapes[2]); + std::shared_ptr matmul; + if (precisions[1] == ov::element::i8) { + matmul = std::make_shared>( + std::vector{ element::f32, element::f32 }, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(data0, element::f32).get(), + ov::op::TemporaryReplaceOutputType(data1, element::f32).get()); + } else { + matmul = std::make_shared(data0, data1); + } auto bias = std::make_shared(matmul, data2); return std::make_shared(NodeVector{bias}, ParameterVector{data0, data1, data2}); } std::shared_ptr Transpose0213MatMulFunction::initOriginal() const { - auto data0 = std::make_shared(precision, input_shapes[0]); - auto data1 = std::make_shared(precision, input_shapes[1]); + auto data0 = std::make_shared(precisions[0], input_shapes[0]); + auto data1 = std::make_shared(precisions[1], input_shapes[1]); auto const_order = std::make_shared(ov::element::i32, Shape {4}, std::vector{0, 2, 1, 3}); std::shared_ptr result; switch (transpose_position) { case 0: { auto transpose = std::make_shared(data0, const_order); - result = std::make_shared(transpose, data1); + if (precisions[1] == ov::element::i8) { + result = std::make_shared>( + std::vector{element::f32, element::f32}, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(transpose, element::f32).get(), + ov::op::TemporaryReplaceOutputType(data1, element::f32).get()); + } else { + result = std::make_shared(transpose, data1); + } break; } case 1: { auto transpose = std::make_shared(data1, const_order); - result = std::make_shared(data0, transpose); + if (precisions[1] == ov::element::i8) { + result = std::make_shared>( + std::vector{element::f32, element::f32}, + std::vector{ element::f32 }, + 
ov::op::TemporaryReplaceOutputType(data0, element::f32).get(), + ov::op::TemporaryReplaceOutputType(transpose, element::f32).get()); + } else { + result = std::make_shared(data0, transpose); + } break; } case 2: { - auto matmul = std::make_shared(data0, data1); + std::shared_ptr matmul; + if (precisions[1] == ov::element::i8) { + matmul = std::make_shared>( + std::vector{element::f32, element::f32}, + std::vector{ element::f32 }, + ov::op::TemporaryReplaceOutputType(data0, element::f32).get(), + ov::op::TemporaryReplaceOutputType(data1, element::f32).get()); + } else { + matmul = std::make_shared(data0, data1); + } result = std::make_shared(matmul, const_order); break; }