diff --git a/src/plugins/intel_cpu/src/cpu_memory.cpp b/src/plugins/intel_cpu/src/cpu_memory.cpp index 8d2fb700494..19efccff0f8 100644 --- a/src/plugins/intel_cpu/src/cpu_memory.cpp +++ b/src/plugins/intel_cpu/src/cpu_memory.cpp @@ -245,6 +245,49 @@ void MemoryMngrWithReuse::destroy(void *ptr) { dnnl::impl::free(ptr); } +void* MemoryMngrRealloc::getRawPtr() const noexcept { + return m_data.get(); +} + +void MemoryMngrRealloc::setExtBuff(void *ptr, size_t size) { + m_useExternalStorage = true; + m_memUpperBound = size; + m_data = decltype(m_data)(ptr, release); +} + +bool MemoryMngrRealloc::resize(size_t size) { + constexpr int cacheLineSize = 64; + constexpr size_t growFactor = 2; + bool sizeChanged = false; + if (size > m_memUpperBound) { + size *= growFactor; + void *ptr = dnnl::impl::malloc(size, cacheLineSize); + if (!ptr) { + OPENVINO_THROW("Failed to allocate ", size, " bytes of memory"); + } + + if (auto src = m_data.get()) { + std::memcpy(ptr, src, m_memUpperBound); + } + + m_memUpperBound = size; + m_useExternalStorage = false; + m_data = decltype(m_data)(ptr, destroy); + sizeChanged = true; + } + return sizeChanged; +} + +bool MemoryMngrRealloc::hasExtBuffer() const noexcept { + return m_useExternalStorage; +} + +void MemoryMngrRealloc::release(void *ptr) {} + +void MemoryMngrRealloc::destroy(void *ptr) { + dnnl::impl::free(ptr); +} + void* DnnlMemoryMngr::getRawPtr() const noexcept { return m_pMemMngr->getRawPtr(); } diff --git a/src/plugins/intel_cpu/src/cpu_memory.h b/src/plugins/intel_cpu/src/cpu_memory.h index fefe011a85a..4b3b8c5ca30 100644 --- a/src/plugins/intel_cpu/src/cpu_memory.h +++ b/src/plugins/intel_cpu/src/cpu_memory.h @@ -89,6 +89,23 @@ private: static void destroy(void *ptr); }; +class MemoryMngrRealloc : public IMemoryMngr { +public: + MemoryMngrRealloc() : m_data(nullptr, release) {} + void* getRawPtr() const noexcept override; + void setExtBuff(void* ptr, size_t size) override; + bool resize(size_t size) override; + bool hasExtBuffer() const noexcept override; + +private: + bool m_useExternalStorage = false; + size_t m_memUpperBound = 0ul; + std::unique_ptr m_data; + + static void release(void *ptr); + static void destroy(void *ptr); +}; + class IMemoryMngrObserver : public IMemoryMngr { public: virtual void registerMemory(Memory* memPtr) = 0; diff --git a/src/plugins/intel_cpu/src/cpu_types.cpp b/src/plugins/intel_cpu/src/cpu_types.cpp index a4e8b140415..f897ca808c1 100644 --- a/src/plugins/intel_cpu/src/cpu_types.cpp +++ b/src/plugins/intel_cpu/src/cpu_types.cpp @@ -214,6 +214,7 @@ static const TypeToNameMap& get_type_to_name_tbl() { { "Unique", Type::Unique}, { "Ngram", Type::Ngram}, { "ScaledDotProductAttention", Type::ScaledDotProductAttention}, + { "ScaledDotProductAttentionStub", Type::ScaledDotProductAttention}, { "RoPE", Type::RoPE}, }; return type_to_name_tbl; diff --git a/src/plugins/intel_cpu/src/extension.cpp b/src/plugins/intel_cpu/src/extension.cpp index ad1c8787f55..8110a4c2fff 100644 --- a/src/plugins/intel_cpu/src/extension.cpp +++ b/src/plugins/intel_cpu/src/extension.cpp @@ -6,6 +6,7 @@ #include "transformations/cpu_opset/common/op/fully_connected.hpp" #include "transformations/cpu_opset/common/op/leaky_relu.hpp" #include "transformations/cpu_opset/common/op/power_static.hpp" +#include "transformations/cpu_opset/common/op/sdp.hpp" #include "transformations/cpu_opset/common/op/swish_cpu.hpp" #include "transformations/cpu_opset/common/op/ngram.hpp" #include "transformations/cpu_opset/x64/op/mha.hpp" @@ -59,6 +60,7 @@ std::map Extension::getOpSets() { NGRAPH_OP(NgramNode, ov::intel_cpu) NGRAPH_OP_X64(MHANode, ov::intel_cpu) NGRAPH_OP_X64(InteractionNode, ov::intel_cpu) + NGRAPH_OP_X64(ScaledDotProductAttentionStub, ov::intel_cpu) #undef NGRAPH_OP return opset; diff --git a/src/plugins/intel_cpu/src/graph.cpp b/src/plugins/intel_cpu/src/graph.cpp index d6211c9022d..2767b053bdb 100644 --- a/src/plugins/intel_cpu/src/graph.cpp +++ b/src/plugins/intel_cpu/src/graph.cpp @@ -1441,7 +1441,7 @@ void Graph::GetPerfData(std::vector& perfMap) const { } } -void Graph::RemoveEdge(EdgePtr& edge) { +void Graph::RemoveEdge(const EdgePtr& edge) { for (auto it = graphEdges.begin(); it != graphEdges.end(); it++) { if ((*it) == edge) { edge->drop(); @@ -1881,9 +1881,9 @@ void Graph::resolveInPlaceDirection(const NodePtr& node) const { void Graph::SearchInternalStateNodes() { for (auto&& node : graphNodes) { if (node->getType() == Type::MemoryInput) { - auto cur_node = std::dynamic_pointer_cast(node); + auto cur_node = std::dynamic_pointer_cast(node); if (!cur_node) { - OPENVINO_THROW("Cannot cast ", node->getName(), " to MemoryInput"); + OPENVINO_THROW("Cannot cast ", node->getName(), " to MemoryStateNode"); } internalStateNodes.insert({cur_node->getId(), cur_node}); } diff --git a/src/plugins/intel_cpu/src/graph.h b/src/plugins/intel_cpu/src/graph.h index 3456085d3f8..890b9de8bcf 100644 --- a/src/plugins/intel_cpu/src/graph.h +++ b/src/plugins/intel_cpu/src/graph.h @@ -28,7 +28,7 @@ namespace intel_cpu { class SyncInferRequest; namespace node { -class MemoryNode; +class MemoryStateNode; } // namespace node class Graph { @@ -123,7 +123,7 @@ public: void RemoveDroppedNodes(); void RemoveDroppedEdges(); - void RemoveEdge(EdgePtr& edge); + void RemoveEdge(const EdgePtr& edge); void DropNode(const NodePtr& node); void DropDWConvNode(const NodePtr& node); @@ -197,7 +197,7 @@ public: } Status getStatus() const {return status;} - const std::unordered_map>& + const std::unordered_map>& getInternalStateNodes() const { return internalStateNodes; } @@ -259,7 +259,7 @@ private: std::map outputNodesMap; std::unordered_map outputNodesMemMngrMap; - std::unordered_map> internalStateNodes; + std::unordered_map> internalStateNodes; // these node pointers (from graphNodes) are to avoid regular checking for // constantness of nodes in Infer methods and calls of diff --git a/src/plugins/intel_cpu/src/graph_optimizer.cpp b/src/plugins/intel_cpu/src/graph_optimizer.cpp index e6f5ce0ff35..c79f0ed2aa6 100644 --- a/src/plugins/intel_cpu/src/graph_optimizer.cpp +++ b/src/plugins/intel_cpu/src/graph_optimizer.cpp @@ -22,6 +22,7 @@ #include "nodes/reduce.h" #include "nodes/input.h" #include "nodes/rnn.h" +#include "nodes/memory.hpp" #include "nodes/common/cpu_convert.h" #include "onednn/dnnl.h" @@ -182,6 +183,18 @@ void GraphOptimizer::ApplyCommonGraphOptimizations(Graph &graph) { RemoveSameConvert(graph); graph.RemoveDroppedNodes(); + OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "RemoveMemoryInputConvert"); + RemoveMemoryInputConvert(graph); + graph.RemoveDroppedNodes(); + + OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "RemoveConvertMemoryOutput"); + RemoveConvertMemoryOutput(graph); + graph.RemoveDroppedNodes(); + + OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "MatchSdpaKvCache"); + MatchSdpaKvCache(graph); + graph.RemoveDroppedNodes(); + OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "RemoveDroppedEdges"); graph.RemoveDroppedEdges(); } @@ -2710,5 +2723,153 @@ void GraphOptimizer::RemoveSameConvert(Graph& graph) { } } +void GraphOptimizer::RemoveMemoryInputConvert(Graph &graph) { + auto& graphNodes = graph.GetNodes(); + + auto isSuitableNode = [](const NodePtr& node) { + if (Type::Convert != node->getType()) { + return false; + } + + auto parent = node->getParentEdgeAt(0)->getParent(); + if (Type::MemoryInput != parent->getType()) { + return false; + } + + return true; + }; + + for (size_t i = 0; i < graphNodes.size(); i++) { + auto node = graphNodes[i]; + + if (!isSuitableNode(node)) { + continue; + } + graph.DropNode(node); + } +} + +void GraphOptimizer::RemoveConvertMemoryOutput(Graph &graph) { + auto& graphNodes = graph.GetNodes(); + + auto isSuitableNode = [](const NodePtr& node) { + if (Type::Convert != node->getType()) { + return false; + } + + auto&& childEdges = node->getChildEdgesAtPort(0); + for (auto&& edge : childEdges) { + if (Type::MemoryOutput != edge->getChild()->getType()) { + return false; + } + } + + return true; + }; + + for (size_t i = 0; i < graphNodes.size(); i++) { + auto node = graphNodes[i]; + + if (!isSuitableNode(node)) { + continue; + } + graph.DropNode(node); + } +} + +void GraphOptimizer::MatchSdpaKvCache(Graph &graph) { + auto& graphNodes = graph.GetNodes(); + + auto isSuitableMemInput = [](const NodePtr& node) -> bool { + if (Type::MemoryInput != node->getType()) { + return false; + } + NodePtr childSdpa = nullptr; + auto&& childEdges = node->getChildEdgesAtPort(0); + for (auto&& item : childEdges) { + auto childNode = item->getChild(); + if (!one_of(childNode->getType(), Type::ScaledDotProductAttention, Type::ShapeOf)) { + return false; + } + + if (Type::ScaledDotProductAttention == childNode->getType()) { + if (childSdpa && childSdpa != childNode) { + //only one child SDPA supported + return false; + } + childSdpa = childNode; + } + } + + CPU_GRAPH_OPTIMIZER_SCOPE(MatchSdpaKvCache_isSuitableMemInput); + + auto memInputNode = std::dynamic_pointer_cast(node); + OPENVINO_ASSERT(memInputNode, "MemoryInput node ", node->getName(), " has unexpected dynamic type"); + auto& memOutputNode = memInputNode->getOutputNode(); + auto memOutputParent = memOutputNode.getParentEdgeAt(0)->getParent(); + if (memOutputParent != childSdpa) { + return false; + } + return true; + }; + + for (size_t i = 0; i < graphNodes.size(); i++) { + auto node = graphNodes[i]; + if (!isSuitableMemInput(node)) { + continue; + } + + CPU_GRAPH_OPTIMIZER_SCOPE(MatchSdpaKvCache_Node); + + // Node is already modified + if (auto sdpaMemInput = std::dynamic_pointer_cast(node)) { + continue; + } + + auto memInputNode = std::dynamic_pointer_cast(node); + OPENVINO_ASSERT(memInputNode, "MemoryInput node ", node->getName(), " has unexpected dynamic type"); + + ov::optional input_shape; + ov::optional input_prc; + + if (!node->getParentEdges().empty()) { + input_shape = ov::optional(node->getInputShapeAtPort(0)); + input_prc = ov::optional(node->getOriginalInputPrecisionAtPort(0)); + } + + auto memInputSdpa = std::make_shared( + memInputNode->getId(), + memInputNode->getName(), + memInputNode->getTypeStr(), + memInputNode->getOutputShapeAtPort(0), + memInputNode->getOriginalOutputPrecisionAtPort(0), + graph.getGraphContext(), + input_shape, + input_prc); + + if (!memInputNode->getParentEdges().empty()) { + auto parentEdge = memInputNode->getParentEdgeAt(0); + auto newEdge = std::make_shared(parentEdge->getParent(), memInputSdpa, parentEdge->getInputNum(), 0); + memInputSdpa->addEdge(newEdge); + graph.GetEdges().push_back(newEdge); + graph.RemoveEdge(parentEdge); + } + + for (auto&& edge : memInputNode->getChildEdgesAtPort(0)) { + auto newEdge = std::make_shared(memInputSdpa, edge->getChild(), 0, edge->getOutputNum()); + memInputSdpa->addEdge(newEdge); + graph.GetEdges().push_back(newEdge); + graph.RemoveEdge(edge); + } + + //link with memory output + auto& memOutput = memInputNode->getOutputNode(); + memInputSdpa->registerOutputNode(&memOutput); + + graph.GetNodes().push_back(memInputSdpa); + graph.DropNode(memInputNode); + } +} + } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/graph_optimizer.h b/src/plugins/intel_cpu/src/graph_optimizer.h index bb6494758f3..48940adaae4 100644 --- a/src/plugins/intel_cpu/src/graph_optimizer.h +++ b/src/plugins/intel_cpu/src/graph_optimizer.h @@ -49,6 +49,9 @@ private: void MergeTransposeAndReorder(Graph &graph); void reshapeRnnSeq(Graph &graph); void RemoveSameConvert(Graph &graph); + void RemoveMemoryInputConvert(Graph &graph); + void RemoveConvertMemoryOutput(Graph &graph); + void MatchSdpaKvCache(Graph &graph); }; } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/infer_request.cpp b/src/plugins/intel_cpu/src/infer_request.cpp index d0da824e0d4..2ca2913b7cc 100644 --- a/src/plugins/intel_cpu/src/infer_request.cpp +++ b/src/plugins/intel_cpu/src/infer_request.cpp @@ -61,7 +61,7 @@ void SyncInferRequest::create_infer_request() { init_tensor(it.first); } - //create states according to the list of the MemoryNodes + //create states according to the list of the MemoryStateNodes for (auto&& node : m_graph->getInternalStateNodes()) { m_memory_states.emplace_back(node.second->makeState()); } diff --git a/src/plugins/intel_cpu/src/memory_state.cpp b/src/plugins/intel_cpu/src/memory_state.cpp index fc74e6094ed..1594ebb6da4 100644 --- a/src/plugins/intel_cpu/src/memory_state.cpp +++ b/src/plugins/intel_cpu/src/memory_state.cpp @@ -14,13 +14,81 @@ using namespace InferenceEngine; namespace ov { namespace intel_cpu { -VariableStateDoubleBuffer::VariableStateDoubleBuffer(std::string name, - const MemBuilder& mem_build, - MemoryDescPtr external_desc, - MemoryCPtr init_val) : - IVariableState{name}, m_external_desc{external_desc} { - reset_prime_mem(mem_build()); - reset_second_mem(mem_build()); +VariableStateBase::VariableStateBase(const std::string& name, const MemoryDescPtr& external_desc) : + IVariableState{name} , m_external_desc{external_desc} {} + +MemoryDescPtr VariableStateBase::to_static(const MemoryDescPtr& desc) { + if (!desc->isDefined()) { + auto&& current_dims = desc->getShape().getDims(); + VectorDims new_dims(current_dims.size()); + std::transform(current_dims.begin(), current_dims.end(), new_dims.begin(), [](Dim x) { + return x == Shape::UNDEFINED_DIM ? 0 : x; }); + + return desc->cloneWithNewDims(new_dims, true); + } + return desc; +} + +const dnnl::engine& VariableStateBase::get_engine() { + static const dnnl::engine eng(dnnl::engine::kind::cpu, 0); + return eng; +} + +void VariableStateBase::set_state(const ov::SoPtr& state) { + m_state = state; // simply to extend the lifetime + auto state_desc = MemoryDescUtils::generateCpuBlockedMemoryDesc(m_state); + + const auto& shape = state_desc->getShape(); + + if (input_mem()->getShape() != shape) { + auto new_desc = internal_desc()->cloneWithNewDims(shape.getStaticDims()); + input_mem()->redefineDesc(new_desc); + } + + auto src = m_state->data(); + + Memory mem(get_engine(), state_desc, src); + input_mem()->load(mem); +} + +ov::SoPtr VariableStateBase::get_state() const { + const auto& current_dims = internal_state_mem()->getStaticDims(); + auto current_ext_desc = m_external_desc->cloneWithNewDims(current_dims); + auto current_internal_desc = internal_state_mem()->getDescPtr(); + + if (current_ext_desc->isCompatible(*current_internal_desc)) { + return std::make_shared(internal_state_mem()); + } + + //test precision + { + auto internal_prc = current_internal_desc->getPrecision(); + auto tmp_desc = current_ext_desc->cloneWithNewPrecision(internal_prc); + if (tmp_desc->isCompatible(*current_internal_desc)) { + auto mem = std::make_shared(get_engine(), current_ext_desc); + size_t elements_to_convert = internal_state_mem()->getDescWithType()->getPaddedElementsCount(); + auto external_prc = current_ext_desc->getPrecision(); + + cpu_convert(internal_state_mem()->getData(), mem->getData(), internal_prc, external_prc, elements_to_convert); + return std::make_shared(mem); + } + } + + //reorder + auto mem = std::make_shared(get_engine(), current_ext_desc); + mem->load(*(internal_state_mem())); + return std::make_shared(mem); +} + +VariableStateDoubleBuffer::VariableStateDoubleBuffer(const std::string& name, + const MemoryPtr& first_buffer, + const MemoryPtr& second_buffer, + const MemoryDescPtr& external_desc, + const MemoryCPtr& init_val) : + VariableStateBase(name, external_desc) { + OPENVINO_ASSERT(first_buffer && second_buffer); + reset_prime_mem(first_buffer); + reset_second_mem(second_buffer); m_internal_desc = prime_mem()->getDescPtr(); auto&& shape = m_internal_desc->getShape(); //TODO what if by some reason we already have internal static state while the node is dynamic, is it even possible? @@ -38,58 +106,6 @@ VariableStateDoubleBuffer::VariableStateDoubleBuffer(std::string name, } } -void VariableStateDoubleBuffer::set_state(const ov::SoPtr& state) { - m_state = state; // simply to extend the lifetime - auto state_desc = MemoryDescUtils::generateCpuBlockedMemoryDesc(m_state); - - const auto& shape = state_desc->getShape(); - - if (prime_mem()->getShape() != shape) { - auto new_desc = m_internal_desc->cloneWithNewDims(shape.getStaticDims()); - prime_mem()->redefineDesc(new_desc); - } - - auto src = m_state->data(); - - Memory mem(get_engine(), state_desc, src); - prime_mem()->load(mem); -} - -const dnnl::engine& VariableStateDoubleBuffer::get_engine() const { - static const dnnl::engine eng(dnnl::engine::kind::cpu, 0); - return eng; -} - -ov::SoPtr VariableStateDoubleBuffer::get_state() const { - //TODO , in general case must be synchronized - const auto& current_dims = prime_mem()->getStaticDims(); - auto current_ext_desc = m_external_desc->cloneWithNewDims(current_dims); - auto current_internal_desc = prime_mem()->getDescPtr(); - - if (current_ext_desc->isCompatible(*current_internal_desc)) { - return std::make_shared(prime_mem()); - } - - //test precision - { - auto internal_prc = current_internal_desc->getPrecision(); - auto tmp_desc = current_ext_desc->cloneWithNewPrecision(internal_prc); - if (tmp_desc->isCompatible(*current_internal_desc)) { - auto mem = std::make_shared(get_engine(), current_ext_desc); - size_t elements_to_convert = prime_mem()->getDescWithType()->getPaddedElementsCount(); - auto external_prc = current_ext_desc->getPrecision(); - - cpu_convert(prime_mem()->getData(), mem->getData(), internal_prc, external_prc, elements_to_convert); - return std::make_shared(mem); - } - } - - //reorder - auto mem = std::make_shared(get_engine(), current_ext_desc); - mem->load(*(prime_mem())); - return std::make_shared(mem); -} - void VariableStateDoubleBuffer::reset() { auto new_desc = to_static(m_internal_desc); for (auto&& mem : m_internal_mem) { @@ -100,18 +116,6 @@ void VariableStateDoubleBuffer::reset() { } } -MemoryDescPtr VariableStateDoubleBuffer::to_static(const MemoryDescPtr& desc) { - if (!desc->isDefined()) { - auto&& current_dims = desc->getShape().getDims(); - VectorDims new_dims(current_dims.size()); - std::transform(current_dims.begin(), current_dims.end(), new_dims.begin(), [](Dim x) { - return x == Shape::UNDEFINED_DIM ? 0 : x; }); - - return desc->cloneWithNewDims(new_dims, true); - } - return desc; -} - void VariableStateDoubleBuffer::commit() { buffer_num ^= 0x01; } @@ -128,5 +132,59 @@ MemoryDescPtr VariableStateDoubleBuffer::internal_desc() const { return m_internal_desc; } +MemoryPtr VariableStateDoubleBuffer::internal_state_mem() const { + return prime_mem(); +} + +VariableStateSingleBuffer::VariableStateSingleBuffer(const std::string& name, + const MemoryPtr& buffer, + const MemoryDescPtr& external_desc, + const MemoryCPtr& init_val) : + VariableStateBase(name, external_desc) { + OPENVINO_ASSERT(buffer); + m_internal_mem = buffer; + m_internal_desc = m_internal_mem->getDescPtr(); + auto&& shape = m_internal_desc->getShape(); + //TODO what if by some reason we already have internal static state while the node is dynamic, is it even possible? + + if (shape.isStatic()) { + if (init_val) { + m_internal_mem->load(*init_val); + } else { + m_internal_mem->nullify(); + } + } else { + //in the case of the original desc has dynamic shape we create an empty tensor + auto new_desc = to_static(m_internal_desc); + m_internal_mem->redefineDesc(new_desc); + } +} + +void VariableStateSingleBuffer::reset() { + auto new_desc = to_static(m_internal_desc); + m_internal_mem->redefineDesc(new_desc); + m_internal_mem->nullify(); +} + +MemoryPtr VariableStateSingleBuffer::input_mem() { + return m_internal_mem; +} + +MemoryPtr VariableStateSingleBuffer::output_mem() { + return m_internal_mem; +} + +MemoryDescPtr VariableStateSingleBuffer::internal_desc() const { + return m_internal_desc; +} + +MemoryPtr VariableStateSingleBuffer::internal_state_mem() const { + return m_internal_mem; +} + +void VariableStateSingleBuffer::commit() { + //nothing to do +} + } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/memory_state.h b/src/plugins/intel_cpu/src/memory_state.h index 43d6206c56f..a5e4746435f 100644 --- a/src/plugins/intel_cpu/src/memory_state.h +++ b/src/plugins/intel_cpu/src/memory_state.h @@ -28,20 +28,35 @@ public: virtual MemoryDescPtr internal_desc() const = 0; }; -class VariableStateDoubleBuffer : public IVariableState { +class VariableStateBase : public IVariableState { public: - using MemBuilder = std::function; + VariableStateBase(const std::string& name, const MemoryDescPtr& external_desc); -public: - VariableStateDoubleBuffer(std::string name, - const MemBuilder& mem_build, - MemoryDescPtr external_desc, - MemoryCPtr init_val); //ov::IVariableState - void reset() override; void set_state(const ov::SoPtr& state) override; ov::SoPtr get_state() const override; +protected: + virtual MemoryPtr internal_state_mem() const = 0; + + static MemoryDescPtr to_static(const MemoryDescPtr& desc); + static const dnnl::engine& get_engine(); + +protected: + MemoryDescPtr m_external_desc; +}; + +class VariableStateDoubleBuffer : public VariableStateBase { +public: + VariableStateDoubleBuffer(const std::string& name, + const MemoryPtr& first_buffer, + const MemoryPtr& second_buffer, + const MemoryDescPtr& external_desc, + const MemoryCPtr& init_val); + + //ov::IVariableState + void reset() override; + //ov::intel_cpu::IVariableState void commit() override; @@ -50,8 +65,6 @@ public: MemoryDescPtr internal_desc() const override; private: - static MemoryDescPtr to_static(const MemoryDescPtr& desc); - void reset_prime_mem(const MemoryPtr& mem) { m_internal_mem[buffer_num] = mem; } @@ -68,16 +81,38 @@ private: return m_internal_mem[buffer_num ^ 0x1]; } - - const dnnl::engine& get_engine() const; + MemoryPtr internal_state_mem() const override; private: - MemoryDescPtr m_external_desc; MemoryDescPtr m_internal_desc; //mem desc required by the graph internal tensor std::array m_internal_mem{}; size_t buffer_num = 0; }; +class VariableStateSingleBuffer : public VariableStateBase { +public: + VariableStateSingleBuffer(const std::string& name, + const MemoryPtr& buffer, + const MemoryDescPtr& external_desc, + const MemoryCPtr& init_val); + //ov::IVariableState + void reset() override; + + //ov::intel_cpu::IVariableState + void commit() override; + + MemoryPtr input_mem() override; + MemoryPtr output_mem() override; + MemoryDescPtr internal_desc() const override; + +private: + MemoryPtr internal_state_mem() const override; + +private: + MemoryDescPtr m_internal_desc; //mem desc required by the graph internal tensor + MemoryPtr m_internal_mem; +}; + using MemStatePtr = std::shared_ptr; using MemStateCPtr = std::shared_ptr; } // namespace intel_cpu diff --git a/src/plugins/intel_cpu/src/node.cpp b/src/plugins/intel_cpu/src/node.cpp index e90b02eed7d..f6dc2284200 100644 --- a/src/plugins/intel_cpu/src/node.cpp +++ b/src/plugins/intel_cpu/src/node.cpp @@ -1673,7 +1673,7 @@ void Node::updateLastInputDims() { } for (size_t i = 0; i < lastInputDims.size(); i++) - lastInputDims[i] = getParentEdgesAtPort(i)[0]->getMemory().getStaticDims(); + lastInputDims[i] = getParentEdgesAtPort(i)[0]->getMemory().getDesc().getShape().getDims(); } bool Node::canFuseSimpleOperation(const NodePtr& node) const { diff --git a/src/plugins/intel_cpu/src/nodes/common/arbitrary_order_desc_creator.cpp b/src/plugins/intel_cpu/src/nodes/common/arbitrary_order_desc_creator.cpp new file mode 100644 index 00000000000..c7a266e4530 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/common/arbitrary_order_desc_creator.cpp @@ -0,0 +1,39 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "arbitrary_order_desc_creator.h" + +namespace ov { +namespace intel_cpu { + +ArbitraryOrderDescCreator::ArbitraryOrderDescCreator(VectorDims order) : + m_order(std::move(order)) { + OPENVINO_ASSERT(std::adjacent_find(m_order.begin(), m_order.end()) == m_order.end(), + "Can't construct ArbitraryOrderDescCreator, order vector contains repetitive elements", + vec2str(m_order)); +} + +CpuBlockedMemoryDesc +ArbitraryOrderDescCreator::createDesc(const ov::element::Type& precision, const Shape& srcShape) const { + auto&& dims = srcShape.getDims(); + OPENVINO_ASSERT(dims.size() == m_order.size(), + "Couldn't create a tensor descriptor, shape and order size mismatch. Shape: ", + vec2str(dims), + " order: ", + vec2str(m_order)); + + VectorDims blkDims(dims.size()); + for (size_t i = 0; i < dims.size(); ++i) { + blkDims[i] = dims[m_order[i]]; + } + + return CpuBlockedMemoryDesc(precision, srcShape, blkDims, m_order); +} + +size_t ArbitraryOrderDescCreator::getMinimalRank() const { + return m_order.size(); +} + +} // namespace intel_cpu +} // namespace ov \ No newline at end of file diff --git a/src/plugins/intel_cpu/src/nodes/common/arbitrary_order_desc_creator.h b/src/plugins/intel_cpu/src/nodes/common/arbitrary_order_desc_creator.h new file mode 100644 index 00000000000..591dad07968 --- /dev/null +++ b/src/plugins/intel_cpu/src/nodes/common/arbitrary_order_desc_creator.h @@ -0,0 +1,24 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "blocked_desc_creator.h" + +namespace ov { +namespace intel_cpu { + +class ArbitraryOrderDescCreator : public BlockedDescCreator { +public: + ArbitraryOrderDescCreator(VectorDims order); + + CpuBlockedMemoryDesc createDesc(const ov::element::Type& precision, const Shape& srcShape) const override; + size_t getMinimalRank() const override; + +private: + VectorDims m_order; +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/input.cpp b/src/plugins/intel_cpu/src/nodes/input.cpp index 6aa2b5d39d2..86f4d47a97f 100644 --- a/src/plugins/intel_cpu/src/nodes/input.cpp +++ b/src/plugins/intel_cpu/src/nodes/input.cpp @@ -395,6 +395,10 @@ Input::Input(const Shape& shape, const GraphContext::CPtr context) : Node(type, name, context) { constant = ConstantType::NoConst; + isDynamic = shape.isDynamic(); + if (isDynamic) { + shapeInference = PassThroughShapeInferFactory().makeShapeInfer(); + } if (getType() == Type::Input) { outputShapes.emplace_back(shape); addOriginalOutputPrecision(prc); diff --git a/src/plugins/intel_cpu/src/nodes/memory.cpp b/src/plugins/intel_cpu/src/nodes/memory.cpp index 14e2e7de0a5..f915703f071 100644 --- a/src/plugins/intel_cpu/src/nodes/memory.cpp +++ b/src/plugins/intel_cpu/src/nodes/memory.cpp @@ -11,6 +11,8 @@ #include "utils/general_utils.h" #include "memory_desc/dnnl_blocked_memory_desc.h" #include "utils/ngraph_utils.hpp" +#include "shape_inference/shape_inference_pass_through.hpp" +#include "common/arbitrary_order_desc_creator.h" using namespace dnnl; using namespace InferenceEngine; @@ -23,9 +25,9 @@ std::mutex MemoryNodeVirtualEdge::holderMutex; MemoryNode::MemoryNode(const std::shared_ptr& op) { if (auto assignOp = ov::as_type_ptr(op)) { - _id = assignOp->get_variable_id(); + m_id = assignOp->get_variable_id(); } else if (auto readValueOp = ov::as_type_ptr(op)) { - _id = readValueOp->get_variable_id(); + m_id = readValueOp->get_variable_id(); } else { OPENVINO_THROW("Unexpected ov::Node type: ", op->get_type_info().name, " in MemoryNode"); } @@ -61,7 +63,7 @@ MemoryOutput::~MemoryOutput() { MemoryNodeVirtualEdge::remove(this, holder); } -MemoryInput& MemoryOutput::getInputNode() { +MemoryInputBase& MemoryOutput::getInputNode() { OPENVINO_ASSERT(inputNode, "MemoryOutput ", getName(), " doesn't have sibling input"); return *inputNode; } @@ -196,19 +198,19 @@ void MemoryOutput::executeDynamicImpl(dnnl::stream strm) { execute(strm); } -void MemoryOutput::registerInputNode(MemoryInput* node) { +void MemoryOutput::registerInputNode(MemoryInputBase* node) { if (inputNode == node) { return; } if (inputNode) { inputNode->deregisterSibling(this); } inputNode = node; inputNode->registerOutputNode(this); } -void MemoryOutput::deregisterSibling(MemoryNode* node) { +void MemoryOutput::deregisterSibling(MemoryInputBase* node) { if (node == inputNode) { inputNode = nullptr; } } -bool MemoryInput::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { +bool MemoryInputBase::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { try { if (!one_of(op->get_type_info(), ov::op::v3::ReadValue::get_type_info_static(), @@ -222,8 +224,8 @@ bool MemoryInput::isSupportedOperation(const std::shared_ptr& op return true; } -MemoryInput::MemoryInput(const std::shared_ptr& op, const GraphContext::CPtr ctx) - : Input(op, ctx), MemoryNode(op) { +MemoryInputBase::MemoryInputBase(const std::shared_ptr& op, const GraphContext::CPtr ctx) + : Input(op, ctx), MemoryStateNode(op) { std::string errorMessage; if (!isSupportedOperation(op, errorMessage)) { OPENVINO_THROW_NOT_IMPLEMENTED(errorMessage); @@ -233,7 +235,32 @@ MemoryInput::MemoryInput(const std::shared_ptr& op, const GraphContext } } -void MemoryInput::createPrimitive() { +MemoryInputBase::MemoryInputBase(const std::string id, + const std::string& name, + const std::string& type, + const Shape& output_shape, + const ov::element::Type& output_prc, + const GraphContext::CPtr context, + const ov::optional& input_shape, + const ov::optional& input_prc) : + Input(output_shape, output_prc, name, type, context), MemoryStateNode(id) { + outputShapes.emplace_back(output_shape); + addOriginalOutputPrecision(output_prc); + if (input_shape) { + inputShapes.push_back(*input_shape); + isDynamic = isDynamic || input_shape->isDynamic(); + if (isDynamic && !shapeInference) { + shapeInference = PassThroughShapeInferFactory().makeShapeInfer(); + } + } + if (input_prc) { + addOriginalInputPrecision(*input_prc); + } + // We don't need to use a virtual edge since this constructor is used in transformations and + // this is their responsibility to link the input/output nodes properly +} + +void MemoryInputBase::createPrimitive() { Input::createPrimitive(); if (!inputShapes.empty()) { auto parentEdge = getParentEdgeAt(0); @@ -244,63 +271,7 @@ void MemoryInput::createPrimitive() { } } -void MemoryInput::initSupportedPrimitiveDescriptors() { - if (!supportedPrimitiveDescriptors.empty()) - return; - - auto&& shape = getOutputShapeAtPort(0); - auto precision = getOriginalOutputPrecisionAtPort(0); - auto&& descCreators = ov::intel_cpu::BlockedDescCreator::getCommonCreators(); - - NodeConfig config; - - if (!getParentEdges().empty()) { - PortConfig inPortConfig; - - inPortConfig.inPlace(-1); - inPortConfig.constant(false); - inPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape)); - - config.inConfs.push_back(std::move(inPortConfig)); - } - - PortConfig outPortConfig; - - outPortConfig.inPlace(0); - outPortConfig.constant(false); - outPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape)); - - config.outConfs.push_back(std::move(outPortConfig)); - - supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown); -} - -void MemoryInput::initOptimalPrimitiveDescriptor() { - // Mimic the child node memory desc to avoid extra reorder - auto childEdge = getChildEdgeAt(0); - auto child = childEdge->getChild(); - auto childPd = child->getSelectedPrimitiveDescriptor(); - OPENVINO_ASSERT(childPd, - child->getTypeStr(), " ", - child->getName(), - "failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set"); - - const auto& childConfig = childPd->getConfig(); - auto mem_desc = childConfig.inConfs[childEdge->getOutputNum()].getMemDesc(); - - auto selectedPd = getSelectedPrimitiveDescriptor(); - OPENVINO_ASSERT(selectedPd, - "MemoryInput ", - getName(), - " failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set"); - - auto config = selectedPd->getConfig(); - config.outConfs.front().setMemDesc(mem_desc); - //bypass any checks, we enforce the child descriptor - selectedPd->setConfig(config); -} - -void MemoryInput::resolveInPlaceEdges(Edge::LOOK look) { +void MemoryInputBase::resolveInPlaceEdges(Edge::LOOK look) { if (!(look & Edge::LOOK_UP)) { Node::resolveInPlaceEdges(look); return; @@ -324,17 +295,17 @@ void MemoryInput::resolveInPlaceEdges(Edge::LOOK look) { } } -MemoryInput::~MemoryInput() { +MemoryInputBase::~MemoryInputBase() { if (outputNode) { outputNode->deregisterSibling(this); } MemoryNodeVirtualEdge::remove(this, holder); } -MemoryOutput& MemoryInput::getOutputNode() { +MemoryOutput& MemoryInputBase::getOutputNode() { OPENVINO_ASSERT(outputNode, "MemoryOutput ", getName(), " doesn't have sibling input"); return *outputNode; } -void MemoryInput::assignState(MemStatePtr newState) { +void MemoryInputBase::assignState(MemStatePtr newState) { assignedMem = newState->input_mem(); OPENVINO_ASSERT(assignedMem, @@ -387,40 +358,18 @@ void MemoryInput::assignState(MemStatePtr newState) { getOutputNode().assignExtMemory(newState->output_mem(), newState->internal_desc()); } -MemStatePtr MemoryInput::makeState() const { - // assume ov::Tensor is always dense - auto original_desc = - std::make_shared(getOriginalOutputPrecisionAtPort(0), outputShapes.at(0)); - - auto mem_desc = getBaseMemDescAtOutputPort(0); - const auto& eng = getEngine(); - - auto state_name = getId(); - - // Remove suffix with pair ID. Internal information. - auto suffix_idx = state_name.find("/id="); - if (suffix_idx != std::string::npos) { - state_name = state_name.substr(0, suffix_idx); - } - - return std::make_shared(state_name, - [mem_desc, eng](){ return std::make_shared(eng, mem_desc); }, - original_desc, - getMemoryPtr()); -} - -void MemoryInput::registerOutputNode(MemoryOutput* node) { +void MemoryInputBase::registerOutputNode(MemoryOutput* node) { if (outputNode == node) { return; } if (outputNode) { outputNode->deregisterSibling(this); } outputNode = node; outputNode->registerInputNode(this); } -void MemoryInput::deregisterSibling(MemoryNode* node) { +void MemoryInputBase::deregisterSibling(MemoryOutput* node) { if (node == outputNode) { outputNode = nullptr; } } -MemoryNodeVirtualEdge::Holder* MemoryNodeVirtualEdge::registerInput(MemoryInput * node) { +MemoryNodeVirtualEdge::Holder* MemoryNodeVirtualEdge::registerInput(MemoryInputBase * node) { std::lock_guard lock{MemoryNodeVirtualEdge::holderMutex}; // in case of output already registered auto& holder = MemoryNodeVirtualEdge::getExisted(); @@ -441,7 +390,7 @@ MemoryNodeVirtualEdge::Holder* MemoryNodeVirtualEdge::registerOutput(MemoryOutpu auto& holder = MemoryNodeVirtualEdge::getExisted(); auto sibling = MemoryNodeVirtualEdge::getByName(holder, node->getId()); if (sibling != nullptr) { - auto inputNode = dynamic_cast(sibling); + auto inputNode = dynamic_cast(sibling); OPENVINO_ASSERT(inputNode != nullptr); node->registerInputNode(inputNode); } else { @@ -458,6 +407,210 @@ void MemoryNodeVirtualEdge::remove(MemoryNode * node, Holder* holder) { }); } } + +void MemoryInput::initSupportedPrimitiveDescriptors() { + if (!supportedPrimitiveDescriptors.empty()) + return; + + auto&& shape = getOutputShapeAtPort(0); + auto precision = getOriginalOutputPrecisionAtPort(0); + auto&& descCreators = ov::intel_cpu::BlockedDescCreator::getCommonCreators(); + + NodeConfig config; + + if (!getParentEdges().empty()) { + PortConfig inPortConfig; + + inPortConfig.inPlace(-1); + inPortConfig.constant(false); + inPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape)); + + config.inConfs.push_back(std::move(inPortConfig)); + } + + PortConfig outPortConfig; + + outPortConfig.inPlace(0); + outPortConfig.constant(false); + outPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape)); + + config.outConfs.push_back(std::move(outPortConfig)); + + supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown); +} + +void MemoryInput::initOptimalPrimitiveDescriptor() { + // Mimic the child node memory desc to avoid extra reorder + static const Type preferredTypes[] = { + Type::ScaledDotProductAttention, + Type::MatMul, + Type::FullyConnected, + Type::Convolution, + Type::RNNCell, + Type::RNNSeq, + Type::Subgraph + }; + + static const Type skipTypes[] = { + Type::ShapeOf + }; + + auto&& childEdges = getChildEdgesAtPort(0); + EdgePtr childEdge = childEdges.front(); + + if (childEdges.size() > 1) { + // try to prioritize memory desc + for (auto&& item : childEdges) { + auto itemType = item->getChild()->getType(); + if (std::any_of(std::begin(skipTypes), std::end(skipTypes), [=](Type type){ return type == itemType; })) { + continue; + } + if (std::any_of(std::begin(preferredTypes), + std::end(preferredTypes), [=](Type type){ return type == itemType; })) { + childEdge = item; + break; + } + } + } + + auto child = childEdge->getChild(); + auto childPd = child->getSelectedPrimitiveDescriptor(); + OPENVINO_ASSERT(childPd, + child->getTypeStr(), " ", + child->getName(), + "failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set"); + + const auto& childConfig = childPd->getConfig(); + auto mem_desc = childConfig.inConfs[childEdge->getOutputNum()].getMemDesc(); + + auto selectedPd = getSelectedPrimitiveDescriptor(); + OPENVINO_ASSERT(selectedPd, + "MemoryInput ", + getName(), + " failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set"); + + auto config = selectedPd->getConfig(); + config.outConfs.front().setMemDesc(mem_desc); + //bypass any checks, we enforce the child descriptor + selectedPd->setConfig(config); +} + +MemStatePtr MemoryInput::makeState() const { + // assume ov::Tensor is always dense + auto original_desc = + std::make_shared(getOriginalOutputPrecisionAtPort(0), outputShapes.at(0)); + + auto mem_desc = getBaseMemDescAtOutputPort(0); + const auto& eng = getEngine(); + + auto state_name = getId(); + + // Remove suffix with pair ID. Internal information. + auto suffix_idx = state_name.find("/id="); + if (suffix_idx != std::string::npos) { + state_name = state_name.substr(0, suffix_idx); + } + + return std::make_shared(state_name, + std::make_shared(eng, mem_desc), + std::make_shared(eng, mem_desc), + original_desc, + getMemoryPtr()); +} + +bool MemoryInput::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { + return MemoryInputBase::isSupportedOperation(op, errorMessage); +} + +void MemoryInputSDPA::initSupportedPrimitiveDescriptors() { + if (!supportedPrimitiveDescriptors.empty()) + return; + + auto&& shape = getOutputShapeAtPort(0); + auto precision = getOriginalOutputPrecisionAtPort(0); + auto&& descCreators = ov::intel_cpu::BlockedDescCreator::getCommonCreators(); + NodeConfig config; + if (!getParentEdges().empty()) { + PortConfig inPortConfig; + inPortConfig.inPlace(-1); + inPortConfig.constant(false); + inPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape)); + config.inConfs.push_back(std::move(inPortConfig)); + } + + auto&& childEdges = getChildEdgesAtPort(0); + auto itr = std::find_if(childEdges.begin(), childEdges.end(), + [](const EdgePtr& edge){ return Type::ScaledDotProductAttention == edge->getChild()->getType(); }); + + OPENVINO_ASSERT(itr != childEdges.end(), "MemoryInputSDPA isn't attached to an SDPA node"); + auto SDPA = (*itr)->getChild(); + auto childPort = (*itr)->getOutputNum(); + + // Since this is a very specialized implementation, lets mimic SDPA precision and set cabd layout + precision = SDPA->getOriginalInputPrecisionAtPort(childPort); + ArbitraryOrderDescCreator cabdDescCreator({2, 0, 1, 3}); + + PortConfig outPortConfig; + outPortConfig.inPlace(0); + outPortConfig.constant(false); + outPortConfig.setMemDesc(cabdDescCreator.createSharedDesc(precision, shape)); + config.outConfs.push_back(std::move(outPortConfig)); + supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown); +} + +void MemoryInputSDPA::initOptimalPrimitiveDescriptor() { + auto&& childEdges = getChildEdgesAtPort(0); + auto itr = std::find_if(childEdges.begin(), childEdges.end(), + [](const EdgePtr& edge){ return Type::ScaledDotProductAttention == edge->getChild()->getType(); }); + + OPENVINO_ASSERT(itr != childEdges.end(), "MemoryInputSDPA isn't attached to an SDPA node"); + auto childEdge = *itr; + auto child = childEdge->getChild(); + auto childPd = child->getSelectedPrimitiveDescriptor(); + OPENVINO_ASSERT(childPd, + child->getTypeStr(), " ", + child->getName(), + "failed initOptimalPrimitiveDescriptor() call, preferable primitive descriptor is not set"); + + const auto& childConfig = childPd->getConfig(); + auto childPrecision = childConfig.inConfs[childEdge->getOutputNum()].getMemDesc()->getPrecision(); + + auto selectedPd = getSelectedPrimitiveDescriptor(); + OPENVINO_ASSERT(selectedPd, + "MemoryInputSDPA ", + getName(), + " failed initOptimalPrimitiveDescriptor() call, preferable primitive descriptor is not set"); + + auto config = selectedPd->getConfig(); + auto memDesc = config.outConfs.front().getMemDesc(); + auto newMemDesc = memDesc->cloneWithNewPrecision(childPrecision); + config.outConfs.front().setMemDesc(newMemDesc); + //bypass any checks, we enforce the child descriptor precision + selectedPd->setConfig(config); +} + +MemStatePtr MemoryInputSDPA::makeState() const { + // assume ov::Tensor is always dense + auto original_desc = + std::make_shared(getOriginalOutputPrecisionAtPort(0), outputShapes.at(0)); + + auto mem_desc = getBaseMemDescAtOutputPort(0); + const auto& eng = getEngine(); + + auto state_name = getId(); + + // Remove suffix with pair ID. Internal information. + auto suffix_idx = state_name.find("/id="); + if (suffix_idx != std::string::npos) { + state_name = state_name.substr(0, suffix_idx); + } + + return std::make_shared(state_name, + std::make_shared(eng, mem_desc, std::make_shared(make_unique())), + original_desc, + getMemoryPtr()); +} + } // namespace node } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/memory.hpp b/src/plugins/intel_cpu/src/nodes/memory.hpp index 35d919851f0..5237e5576f5 100644 --- a/src/plugins/intel_cpu/src/nodes/memory.hpp +++ b/src/plugins/intel_cpu/src/nodes/memory.hpp @@ -9,6 +9,7 @@ #include "ie_algorithm.hpp" #include "input.h" #include +#include #include #include #include @@ -20,20 +21,24 @@ namespace intel_cpu { namespace node { class MemoryOutput; -class MemoryInput; +class MemoryInputBase; -class MemoryNode { //TODO , segregate interfaces - std::string _id; +class MemoryNode { public: - explicit MemoryNode(std::string id) : _id(id) {} + explicit MemoryNode(std::string id) : m_id(id) {} explicit MemoryNode(const std::shared_ptr& op); virtual ~MemoryNode() = default; std::string getId() const { - return _id; + return m_id; } - virtual void registerInputNode(MemoryInput*) = 0; - virtual void registerOutputNode(MemoryOutput*) = 0; - virtual void deregisterSibling(MemoryNode*) = 0; + +private: + std::string m_id; +}; + +class MemoryStateNode : public MemoryNode { +public: + using MemoryNode::MemoryNode; virtual void assignState(MemStatePtr newState) = 0; virtual MemStatePtr makeState() const = 0; }; @@ -60,7 +65,7 @@ public: } static Holder* registerOutput(MemoryOutput * node); - static Holder* registerInput(MemoryInput * node); + static Holder* registerInput(MemoryInputBase * node); static void remove(MemoryNode * node, Holder* holder); static std::mutex holderMutex; }; @@ -81,12 +86,8 @@ public: } void resolveInPlaceEdges(Edge::LOOK look) override; - void registerInputNode(MemoryInput* node) override; - void registerOutputNode(MemoryOutput* node) override { - OPENVINO_THROW("MemoryOutput node has no MemoryOutput type sibling!"); - } - - void deregisterSibling(MemoryNode* node) override; + void registerInputNode(MemoryInputBase* node); + void deregisterSibling(MemoryInputBase* node); bool needShapeInfer() const override { return false; } bool needPrepareParams() const override { return false; } @@ -94,38 +95,38 @@ public: void assignExtMemory(const MemoryPtr& mem, const MemoryDescPtr& memDesc); private: - MemoryInput& getInputNode(); - void assignState(MemStatePtr newState) override { - OPENVINO_THROW("Unexpected MemoryOutput::assignState call"); //TODO , segregate interfaces - } - MemStatePtr makeState() const override { - OPENVINO_THROW("Unexpected MemoryOutput::makeState call"); //TODO , segregate interfaces - } + MemoryInputBase& getInputNode(); private: /** * @brief keeps reference to input sibling node */ - MemoryInput* inputNode = nullptr; + MemoryInputBase* inputNode = nullptr; MemoryPtr assignedMem = nullptr; MemoryDescPtr extMemDesc = nullptr; // used for resize MemoryNodeVirtualEdge::Holder* holder = nullptr; ProxyMemoryMngrPtr memMngr = nullptr; }; -class MemoryInput : public Input, public MemoryNode { +class MemoryInputBase : public Input, public MemoryStateNode { public: - MemoryInput(const std::shared_ptr& op, const GraphContext::CPtr context); - ~MemoryInput() override; + MemoryInputBase(const std::shared_ptr& op, const GraphContext::CPtr context); + MemoryInputBase(const std::string id, + const std::string& name, + const std::string& type, + const Shape& output_shape, + const ov::element::Type& output_prc, + const GraphContext::CPtr context, + const ov::optional& input_shape, + const ov::optional& input_prc); + + ~MemoryInputBase() override; static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; bool created() const override { return getType() == Type::MemoryInput; } - void initSupportedPrimitiveDescriptors() override; - void initOptimalPrimitiveDescriptor() override; - void execute(dnnl::stream strm) override {/*pass*/} void executeDynamicImpl(dnnl::stream strm) override {/*pass*/} @@ -133,17 +134,11 @@ public: void resolveInPlaceEdges(Edge::LOOK look) override; - void registerInputNode(MemoryInput* node) override { - OPENVINO_THROW("MemoryInput node has no MemoryInput type sibling!"); - } - - void registerOutputNode(MemoryOutput* node) override; - void deregisterSibling(MemoryNode* node) override; + void registerOutputNode(MemoryOutput* node); + void deregisterSibling(MemoryOutput* node); + // May be extracted to some interface when necessary void assignState(MemStatePtr newState) override; - MemStatePtr makeState() const override; - -private: MemoryOutput& getOutputNode(); private: @@ -156,6 +151,27 @@ private: ProxyMemoryMngrPtr memMngr = nullptr; }; +class MemoryInput : public MemoryInputBase { +public: + using MemoryInputBase::MemoryInputBase; + static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; + + void initSupportedPrimitiveDescriptors() override; + void initOptimalPrimitiveDescriptor() override; + + MemStatePtr makeState() const override; +}; + +class MemoryInputSDPA : public MemoryInputBase { +public: + using MemoryInputBase::MemoryInputBase; + static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; + + void initSupportedPrimitiveDescriptors() override; + void initOptimalPrimitiveDescriptor() override; + + MemStatePtr makeState() const override; +}; } // namespace node } // namespace intel_cpu } // namespace ov diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp index ce5daa38082..2ba51286e79 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.cpp @@ -15,11 +15,12 @@ #include #include -#include "common/cpu_memcpy.h" +#include "openvino/core/parallel.hpp" #include "memory_desc/cpu_memory_desc_utils.h" #include "memory_desc/dnnl_blocked_memory_desc.h" #include "utils/plain_tensor.hpp" #include +#include "common/arbitrary_order_desc_creator.h" #ifdef OV_CPU_WITH_MLAS # include "mlas/sgemm.hpp" @@ -576,48 +577,93 @@ struct MHASingleToken { template struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAttention::Executor { - PlainTensor q_input; // f32[B, L1, H*S] / [B, H, L1, S] - PlainTensor k_input; // f32[B, L1, H*S] - PlainTensor v_input; // f32[B, L1, H*S] - PlainTensor k_cache; // f32[B, H, max_kvLen, S] - PlainTensor v_cache; // f32[B, H, max_kvLen, S] + PlainTensor q_input; // f32[B, H, L1, S] + PlainTensor k_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S] + PlainTensor v_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S] PlainTensor beam_table; // i32[B, max_kvLen] - PlainTensor attn_mask; // f32[B, qLen + kvLen] - float scale_input = 0.0f; // f32[B, qLen + kvLen] - PlainTensor cos_tab; // f32[max_kv_len, rotary_dims//2] - PlainTensor sin_tab; // f32[max_kv_len, rotary_dims//2] - - PlainTensor output_emb; // f32[B, L1, H*S] + PlainTensor attn_buf; // f32[[B|1],[H|1], L1|1, L0+L1] + float scale_input = 0.0f; MHAKernel kernel; MHASingleToken kernel_single_token; - PlainTensor m_query_emb; // query with RoPE position embedding + size_t B, H, L1, L0, S; - ScaledDotProductAttention::Config config; - AttentionExecutor(const ScaledDotProductAttention::Config& _config) : config(_config) {} + Config config; + AttentionExecutor(const Config& _config) : attn_buf(true), config(_config) {} void prepare_attn_mask(MemoryPtr attn_input) { - attn_mask.resize(attn_input->getStaticDims()); + attn_buf.resize(attn_input->getStaticDims()); auto p = reinterpret_cast(attn_input->getData()); for (size_t i = 0; i < attn_input->getSize(); i++) - attn_mask.data()[i] = p[i] ? 0.0f : -FLT_MAX; + attn_buf.data()[i] = p[i] ? 0.0f : -FLT_MAX; + } + + void concat_pastkv(const std::vector& inputs, + const std::vector& outputs, + const PlainTensor& k_input, + const PlainTensor& v_input, + PlainTensor& past_k_output, + PlainTensor& past_v_output) { + if (config.config.fuse_concat) { + k_input.assert_dims({B, 0, L1, S}, true); + v_input.assert_dims({B, 0, L1, S}, true); + auto past_k_idx = inputs.size() - 2; + auto past_k_mem = inputs[past_k_idx + 0]; + L0 = past_k_mem->getStaticDims()[2]; + // k,v may support multiquery + auto Hk = past_k_mem->getStaticDims()[1]; + // [B, H, L0, S] + past_k_output.reset(outputs[1]); + past_v_output.reset(outputs[2]); + parallel_for3d(B, Hk, L1, [&](size_t b, size_t h, size_t m) { + std::memcpy(&past_k_output.at({b, h, m + L0, 0}), + &k_input.at({b, h, m, 0}), + S * sizeof(T)); + std::memcpy(&past_v_output.at({b, h, m + L0, 0}), + &v_input.at({b, h, m, 0}), + S * sizeof(T)); + }); + if (!config.is_concat_inplaced) { + PlainTensor past_k_input, past_v_input; + past_k_input.reset(past_k_mem); + past_v_input.reset(inputs[past_k_idx + 1]); + parallel_for3d(B, Hk, L0, [&](size_t b, size_t h, size_t m) { + std::memcpy(&past_k_output.at({b, h, m, 0}), + &past_k_input.at({b, h, m, 0}), + S * sizeof(T)); + std::memcpy(&past_v_output.at({b, h, m, 0}), + &past_v_input.at({b, h, m, 0}), + S * sizeof(T)); + }); + } + } else { + // k,v inputs are already concatenated + L0 = k_input.size(2) - L1; + k_input.assert_dims({B, 0, L0 + L1, S}, true); + v_input.assert_dims({B, 0, L0 + L1, S}, true); + past_k_output = k_input; + past_v_output = v_input; + } } void execute(dnnl::stream strm, const std::vector& inputs, const std::vector& outputs) override { - bool has_out_transpose = config.output_BLHxS; - bool fuse_causal_attn = config.fuse_causal_attn; - bool is_causal = config.is_causal; - auto input_num = inputs.size(); + bool has_out_transpose = config.config.output_BLHxS; + bool fuse_causal_attn = config.config.fuse_causal_attn; + bool is_causal = config.config.is_causal; + const bool fuse_concat = config.config.fuse_concat; + auto input_num = inputs.size() - (fuse_concat ? 2 : 0); q_input.reset(inputs[0]); k_input.reset(inputs[1]); v_input.reset(inputs[2]); + PlainTensor attn_mask; if (input_num > 3) { // attn_mask if (inputs[3]->getDesc().getPrecision() == ov::element::u8) { // bool->f32 prepare_attn_mask(inputs[3]); + attn_mask = attn_buf; } else { attn_mask.reset(inputs[3]); } @@ -627,30 +673,22 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt } } - size_t B, H, L1, L0, S; - - // q, k, v: [B, H, L1, S] + // q: [B, H, L1, S] B = q_input.size(0); H = q_input.size(1); L1 = q_input.size(2); - L0 = k_input.size(2) - L1; S = q_input.size(-1); - ov::intel_cpu::PlainTensor output_emb(outputs[0]); PlainTensor present_key, present_value; + concat_pastkv(inputs, outputs, k_input, v_input, present_key, present_value); - q_input.assert_dims({B, H, L1, S}); - k_input.assert_dims({B, 0, L0 + L1, S}, true); - v_input.assert_dims({B, 0, L0 + L1, S}, true); - m_query_emb = q_input; - present_key = k_input; - present_value = v_input; + ov::intel_cpu::PlainTensor output_emb(outputs[0]); bool auto_causal; bool use_attn_mask; if (fuse_causal_attn) { assert(attn_mask); - attn_mask.assert_dims({B, 1, 1, L0 + L1}); + attn_mask.assert_dims({B, 1, L1, L0 + L1}); auto_causal = true; use_attn_mask = true; } else { @@ -677,7 +715,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt if (L1 > 1) { // multi-token version - kernel(strm, m_query_emb, present_key, present_value, {}, use_attn_mask ? attn_mask : PlainTensor(), + kernel(strm, q_input, k_input, v_input, {}, use_attn_mask ? attn_mask : PlainTensor(), output_emb, has_out_transpose, auto_causal, scale_input); } else { // 1-token version @@ -685,7 +723,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt // 1, in matrix mutiply, using AMX is not efficency because the M dimension of A will alway be 1 // 2, using float will save the repack cost which typically is required for bf16/int8 opt // 3, using dot product can leverage the SIMD while easily adapt to indirect kv cache - kernel_single_token(m_query_emb, present_key, present_value, {}, use_attn_mask ? attn_mask : PlainTensor(), + kernel_single_token(q_input, present_key, present_value, {}, use_attn_mask ? attn_mask : PlainTensor(), output_emb, beam_table, has_out_transpose, auto_causal, scale_input); } } @@ -700,9 +738,10 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptr(op); if (node) { - m_config.is_causal = node->get_causal(); + m_config.config.is_causal = node->get_causal(); } else { - OPENVINO_THROW("CPU: cast to v13::ScaledDotProductAttention failed."); + const auto node = std::dynamic_pointer_cast(op); + m_config.config = node->get_config(); } } @@ -711,6 +750,80 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() { return; auto rtPrecision = getOriginalInputPrecisionAtPort(0); + NodeConfig config; + auto& creatorsMap = BlockedDescCreator::getCommonCreators(); + auto orginSDPInputNumber = getOriginalInputsNumber() - (m_config.config.fuse_concat ? 2 : 0); + config.inConfs.resize(getOriginalInputsNumber()); + config.outConfs.resize(getOriginalOutputsNumber()); + config.inConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + rtPrecision, getInputShapeAtPort(0))); + config.inConfs[1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + rtPrecision, getInputShapeAtPort(1))); + config.inConfs[2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + rtPrecision, getInputShapeAtPort(2))); + auto nextPortIdx = 3; + if (orginSDPInputNumber > 3) { + // attn_mask + if (getOriginalInputPrecisionAtPort(nextPortIdx) == ov::element::u8) { + config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::u8, getInputShapeAtPort(nextPortIdx))); + } else { + config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::f32, getInputShapeAtPort(nextPortIdx))); + } + nextPortIdx++; + } + if (orginSDPInputNumber > 4) { + config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + ov::element::f32, getInputShapeAtPort(nextPortIdx))); + } + + if (m_config.config.fuse_concat) { + ArbitraryOrderDescCreator cabdDescCreator({2, 0, 1, 3}); + + config.inConfs[orginSDPInputNumber + 0].setMemDesc(cabdDescCreator.createSharedDesc( + rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 0))); + config.inConfs[orginSDPInputNumber + 1].setMemDesc(cabdDescCreator.createSharedDesc( + rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 1))); + + config.outConfs[1].setMemDesc(cabdDescCreator.createSharedDesc( + rtPrecision, getOutputShapeAtPort(1))); + config.outConfs[1].inPlace(orginSDPInputNumber + 0); + config.outConfs[2].setMemDesc(cabdDescCreator.createSharedDesc( + rtPrecision, getOutputShapeAtPort(2))); + config.outConfs[2].inPlace(orginSDPInputNumber + 1); + } + + config.outConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + rtPrecision, getOutputShapeAtPort(0))); + + supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::ref_any); + // may fallback to abcd without inplace + if (m_config.config.fuse_concat) { + config.inConfs[orginSDPInputNumber + 0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 0))); + config.inConfs[orginSDPInputNumber + 1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 1))); + config.outConfs[1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + rtPrecision, getOutputShapeAtPort(1))); + config.outConfs[1].inPlace(-1); + config.outConfs[2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc( + rtPrecision, getOutputShapeAtPort(2))); + config.outConfs[2].inPlace(-1); + supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::ref_any); + } +} + +void ScaledDotProductAttention::createPrimitive() { + if (m_config.config.fuse_concat) { + auto desc = getSelectedPrimitiveDescriptor(); + if (desc == nullptr) + OPENVINO_THROW("has unidentified preferable primitive descriptor"); + + m_config.is_concat_inplaced = desc->getConfig().outConfs[1].inPlace() >= 0; + } + auto rtPrecision = getOriginalInputPrecisionAtPort(0); + if (rtPrecision == ov::element::bf16) { m_executor = std::make_shared>(m_config); } else { @@ -722,29 +835,6 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() { m_executor = std::make_shared>(m_config); #endif } - - // initialize input ports - std::vector inPortConfigs; - inPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getInputShapeAtPort(0), false, -1); - inPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getInputShapeAtPort(1), false, -1); - inPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getInputShapeAtPort(2), false, -1); - if (getOriginalInputsNumber() > 3) { - // attn_mask - if (getOriginalInputPrecisionAtPort(3) == ov::element::u8) { - inPortConfigs.emplace_back(LayoutType::ncsp, ov::element::u8, getInputShapeAtPort(3), false, -1); - } else { - inPortConfigs.emplace_back(LayoutType::ncsp, ov::element::f32, getInputShapeAtPort(3), false, -1); - } - } - if (getOriginalInputsNumber() > 4) { - inPortConfigs.emplace_back(LayoutType::ncsp, ov::element::f32, getInputShapeAtPort(4), false, -1); - } - - // initialize output port - std::vector outPortConfigs; - outPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getOutputShapeAtPort(0), false, -1); - - addSupportedPrimDesc(inPortConfigs, outPortConfigs, impl_desc_type::ref_any); } void ScaledDotProductAttention::execute(dnnl::stream strm) { @@ -760,8 +850,9 @@ void ScaledDotProductAttention::execute(dnnl::stream strm) { bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept { try { - if (!std::dynamic_pointer_cast(op)) { - errorMessage = "Only ScaledDotProductAttention operation are supported"; + if (!std::dynamic_pointer_cast(op) && + !std::dynamic_pointer_cast(op)) { + errorMessage = "Only ScaledDotProductAttention or ScaledDotProductAttentionStub operation are supported"; return false; } // expect shape of q: [B, H, L, S] @@ -770,7 +861,14 @@ bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptrget_input_size() > 3) { + int orgSDPAInput = static_cast(op->get_input_size()); + const auto node = std::dynamic_pointer_cast(op); + if (node) { + if (node->get_config().fuse_concat) { + orgSDPAInput -= 2; + } + } + if (orgSDPAInput > 3) { inRank = op->get_input_partial_shape(3).size(); if (inRank > 4u) { errorMessage = "Doesn't support 'attention mask' with rank: " + std::to_string(inRank); diff --git a/src/plugins/intel_cpu/src/nodes/scaled_attn.h b/src/plugins/intel_cpu/src/nodes/scaled_attn.h index 67da5bde580..7c08ef99faf 100644 --- a/src/plugins/intel_cpu/src/nodes/scaled_attn.h +++ b/src/plugins/intel_cpu/src/nodes/scaled_attn.h @@ -10,6 +10,8 @@ #include #include +#include "transformations/cpu_opset/common/op/sdp.hpp" + namespace ov { namespace intel_cpu { namespace node { @@ -22,6 +24,10 @@ public: bool created() const override { return getType() == Type::ScaledDotProductAttention; } + // pastkv may have zero dimension + bool isExecutable() const override { + return true; + } bool needPrepareParams() const override { return false; } @@ -30,6 +36,7 @@ public: } void initSupportedPrimitiveDescriptors() override; void execute(dnnl::stream strm) override; + void createPrimitive() override; static bool isSupportedOperation(const std::shared_ptr& op, std::string& errorMessage) noexcept; enum KernelTypes { KT_REF, KT_ONEDNN, KT_MLAS}; @@ -40,9 +47,8 @@ private: }; struct Config { - bool output_BLHxS = false; - bool fuse_causal_attn = false; - bool is_causal = false; + ScaledDotProductAttentionStub::Config config; + bool is_concat_inplaced = false; }; Config m_config; diff --git a/src/plugins/intel_cpu/src/nodes/shapeof.cpp b/src/plugins/intel_cpu/src/nodes/shapeof.cpp index e287cde7252..762c4e0d982 100644 --- a/src/plugins/intel_cpu/src/nodes/shapeof.cpp +++ b/src/plugins/intel_cpu/src/nodes/shapeof.cpp @@ -51,12 +51,34 @@ void ShapeOf::initSupportedPrimitiveDescriptors() { ov::element::Type precision = getOriginalInputPrecisionAtPort(0); - const LayoutType dataFormats[4] = { LayoutType::ncsp, LayoutType::nspc, LayoutType::nCsp16c, LayoutType::nCsp8c }; - for (const auto &df : dataFormats) { - addSupportedPrimDesc({{df, precision}}, - {{LayoutType::ncsp, ov::element::i32}}, - impl_desc_type::ref); - } + addSupportedPrimDesc({{LayoutType::ncsp, precision}}, + {{LayoutType::ncsp, ov::element::i32}}, + impl_desc_type::ref); +} + +void ShapeOf::initOptimalPrimitiveDescriptor() { + // Mimic the parent node memory desc to avoid extra reorder + auto parentEdge = getParentEdgeAt(0); + auto parent = parentEdge->getParent(); + auto parentPd = parent->getSelectedPrimitiveDescriptor(); + OPENVINO_ASSERT(parentPd, + parent->getTypeStr(), " ", + parent->getName(), + "failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set"); + + const auto& parentConfig = parentPd->getConfig(); + auto mem_desc = parentConfig.outConfs[parentEdge->getInputNum()].getMemDesc(); + + auto selected_pd = getSelectedPrimitiveDescriptor(); + OPENVINO_ASSERT(selected_pd, + "ShapeOf ", + getName(), + " failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set"); + + auto config = selected_pd->getConfig(); + config.inConfs.front().setMemDesc(mem_desc); + //bypass any checks, we enforce the parent descriptor + selected_pd->setConfig(config); } bool ShapeOf::isExecutable() const { @@ -66,12 +88,12 @@ bool ShapeOf::isExecutable() const { void ShapeOf::execute(dnnl::stream strm) { auto inPtr = getParentEdgeAt(0)->getMemoryPtr(); auto outPtr = getChildEdgeAt(0)->getMemoryPtr(); - auto inDims = inPtr->getStaticDims(); + auto&& inDims = inPtr->getStaticDims(); size_t dimsCount = inDims.size(); if (outPtr->getStaticDims().size() != 1 || dimsCount != outPtr->getStaticDims()[0]) OPENVINO_THROW(errorPrefix, "has inconsistent input shape and output size"); - auto *dst = reinterpret_cast(getChildEdgeAt(0)->getMemoryPtr()->getData()); + auto* dst = reinterpret_cast(outPtr->getData()); for (size_t i = 0; i < dimsCount; i++) { dst[i] = inDims[i]; diff --git a/src/plugins/intel_cpu/src/nodes/shapeof.h b/src/plugins/intel_cpu/src/nodes/shapeof.h index 0955e6f172d..7841299adfc 100644 --- a/src/plugins/intel_cpu/src/nodes/shapeof.h +++ b/src/plugins/intel_cpu/src/nodes/shapeof.h @@ -20,6 +20,7 @@ public: void getSupportedDescriptors() override; void initSupportedPrimitiveDescriptors() override; + void initOptimalPrimitiveDescriptor() override; void execute(dnnl::stream strm) override; bool created() const override; bool needPrepareParams() const override {return false;}; diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdp.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdp.cpp new file mode 100644 index 00000000000..e433d5ad34f --- /dev/null +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdp.cpp @@ -0,0 +1,57 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "sdp.hpp" + +#include + +#include "transformations/itt.hpp" + +ov::intel_cpu::ScaledDotProductAttentionStub::ScaledDotProductAttentionStub(const OutputVector& args, const Config& cfg) + : Op(args), + m_config(cfg) { + constructor_validate_and_infer_types(); +} + +std::shared_ptr ov::intel_cpu::ScaledDotProductAttentionStub::clone_with_new_inputs( + const ov::OutputVector& new_args) const { + INTERNAL_OP_SCOPE(ScaledDotProductAttentionStub_with_new_inputs); + check_new_args_count(this, new_args); + return std::make_shared(new_args, m_config); +} + +void ov::intel_cpu::ScaledDotProductAttentionStub::validate_and_infer_types() { + INTERNAL_OP_SCOPE(ScaledDotProductAttentionStub_validate_and_infer_types); + auto input_num = get_input_size(); + // [B, H, L1, S] + auto q_ps = get_input_partial_shape(0); + // [B, H, L0, S] + auto past_kv_ps = get_input_partial_shape(input_num - 1); + + NODE_VALIDATION_CHECK(this, m_config.output_BLHxS == false); + NODE_VALIDATION_CHECK(this, q_ps.size() >= 3); + if (past_kv_ps.rank().is_static()) { + NODE_VALIDATION_CHECK(this, q_ps.size() == past_kv_ps.size()); + for (size_t i = 0; i < q_ps.size(); i++) { + if (i == q_ps.size() - 2) + continue; + NODE_VALIDATION_CHECK(this, q_ps[i].compatible(past_kv_ps[i])); + } + past_kv_ps[q_ps.size() - 2] += q_ps[q_ps.size() - 2]; + } + set_output_type(0, get_input_element_type(0), q_ps); + set_output_type(1, get_input_element_type(input_num - 1), past_kv_ps); + set_output_type(2, get_input_element_type(input_num - 1), past_kv_ps); +} + +bool ov::intel_cpu::ScaledDotProductAttentionStub::visit_attributes(ov::AttributeVisitor& visitor) { + INTERNAL_OP_SCOPE(ScaledDotProductAttentionStub_visit_attributes); + visitor.start_structure("config"); + visitor.on_attribute("output_BLHxS", m_config.output_BLHxS); + visitor.on_attribute("fuse_causal_attn", m_config.fuse_causal_attn); + visitor.on_attribute("is_causal", m_config.is_causal); + visitor.on_attribute("fuse_concat", m_config.fuse_concat); + visitor.finish_structure(); + return true; +} diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdp.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdp.hpp new file mode 100644 index 00000000000..7cf45b24bd7 --- /dev/null +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/op/sdp.hpp @@ -0,0 +1,50 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +#include "openvino/op/op.hpp" + +namespace ov { +namespace intel_cpu { +/// \brief Scaled dot product attention from PyTorch, fused with Concat +/// +/// \ingroup ov_ops_cpp_api + +class ScaledDotProductAttentionStub : public ov::op::Op { +public: + OPENVINO_OP("ScaledDotProductAttentionStub", "cpu_plugin_opset"); + + ScaledDotProductAttentionStub() = default; + + struct Config { + bool output_BLHxS = false; // true implies that output is [B,L,H*S] + + bool fuse_causal_attn = false; // fuse causal mask and attn mask into attn_mask + bool is_causal = false; // apply causal mask internally + bool fuse_concat = false; // fuse (concat->sdp) ==> sdp + }; + + ScaledDotProductAttentionStub(const OutputVector& args, const Config& cfg); + + std::shared_ptr clone_with_new_inputs(const OutputVector& new_args) const override; + bool visit_attributes(AttributeVisitor& visitor) override; + void validate_and_infer_types() override; + + const Config& get_config() const { + return m_config; + } + + Config& get_config() { + return m_config; + } + +private: + Config m_config; +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdp_fusion.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdp_fusion.cpp new file mode 100644 index 00000000000..a1f9dd24ddc --- /dev/null +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdp_fusion.cpp @@ -0,0 +1,121 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "stateful_sdp_fusion.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "itt.hpp" +#include "ov_ops/type_relaxed.hpp" +#include "transformations/cpu_opset/common/op/sdp.hpp" + +namespace ov { +namespace intel_cpu { + +StatefulSDPFusion::StatefulSDPFusion() { + MATCHER_SCOPE(StatefulSDPFusion); + using namespace ov::pass::pattern; + + auto past_k = wrap_type(); + auto past_v = wrap_type(); + auto convert_past_k = wrap_type({past_k}); + auto convert_past_v = wrap_type({past_v}); + auto concat_input_k = std::make_shared(OutputVector{past_k, convert_past_k}); + auto concat_input_v = std::make_shared(OutputVector{past_v, convert_past_v}); + auto concat_k = wrap_type({concat_input_k, any_input()}); + auto concat_v = wrap_type({concat_input_v, any_input()}); + auto sdp0 = wrap_type({any_input(), concat_k, concat_v}); + auto sdp1 = wrap_type({any_input(), concat_k, concat_v, any_input()}); + auto sdp2 = wrap_type({any_input(), concat_k, concat_v, any_input(), any_input()}); + auto sdp = std::make_shared(OutputVector{sdp0, sdp1, sdp2}); + + ov::matcher_pass_callback callback = [=](Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + auto root = m.get_match_root(); + + auto find_assign = [&](const ov::Output& out, opset6::Assign*& assign, opset1::Convert*& cvt) { + auto present_to = out.get_target_inputs(); + if (present_to.size() != 2) + return; + for (auto& to : present_to) { + auto to_node = to.get_node(); + if (auto convert = dynamic_cast(to_node)) { + auto cvt_targets = convert->get_output_target_inputs(0); + if (cvt_targets.size() == 1) { + to_node = cvt_targets.begin()->get_node(); + cvt = convert; + } + } + assign = dynamic_cast(to_node); + if (assign) + return; + } + }; + + std::shared_ptr read_cvt_k_node, read_cvt_v_node; + const auto sdp_node = ov::as_type_ptr(root); + const auto past_k_node = ov::as_type_ptr(pattern_map.at(past_k).get_node_shared_ptr()); + const auto past_v_node = ov::as_type_ptr(pattern_map.at(past_v).get_node_shared_ptr()); + const auto concat_k_node = ov::as_type_ptr(pattern_map.at(concat_k).get_node_shared_ptr()); + const auto concat_v_node = ov::as_type_ptr(pattern_map.at(concat_v).get_node_shared_ptr()); + if (pattern_map.count(convert_past_k)) { + read_cvt_k_node = ov::as_type_ptr(pattern_map.at(convert_past_k).get_node_shared_ptr()); + read_cvt_v_node = ov::as_type_ptr(pattern_map.at(convert_past_v).get_node_shared_ptr()); + } + opset6::Assign* assign_k_node = nullptr, *assign_v_node = nullptr; + opset1::Convert* assign_cvt_k_node = nullptr, *assign_cvt_v_node = nullptr; + find_assign(concat_k_node, assign_k_node, assign_cvt_k_node); + if (!assign_k_node) + return false; + if (past_k_node->get_variable_id() != assign_k_node->get_variable_id()) + return false; + + find_assign(concat_v_node, assign_v_node, assign_cvt_v_node); + if (!assign_v_node) + return false; + if (past_v_node->get_variable_id() != assign_v_node->get_variable_id()) + return false; + + auto args = sdp_node->input_values(); + args[1] = concat_k_node->input_value(1); + args[2] = concat_v_node->input_value(1); + args.push_back(read_cvt_k_node ? read_cvt_k_node->output(0) : past_k_node->output(0)); + args.push_back(read_cvt_v_node ? read_cvt_v_node->output(0) : past_v_node->output(0)); + ov::intel_cpu::ScaledDotProductAttentionStub::Config config; + + config.is_causal = sdp_node->get_causal(); + config.fuse_concat = true; + + auto old_node = sdp_node; + auto new_node = std::make_shared(args, config); + new_node->set_friendly_name(old_node->get_friendly_name()); + ov::replace_node(old_node, {new_node->output(0)}); + if (assign_cvt_k_node) + assign_cvt_k_node->set_arguments({new_node->output(1)}); + else + assign_k_node->set_arguments({new_node->output(1)}); + + if (assign_cvt_v_node) + assign_cvt_v_node->set_arguments({new_node->output(2)}); + else + assign_v_node->set_arguments({new_node->output(2)}); + + return true; + }; + + auto m = std::make_shared(sdp, matcher_name); + this->register_matcher(m, callback); +} + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdp_fusion.hpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdp_fusion.hpp new file mode 100644 index 00000000000..21f52508681 --- /dev/null +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/stateful_sdp_fusion.hpp @@ -0,0 +1,18 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace ov { +namespace intel_cpu { +class StatefulSDPFusion : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("StatefulSDPFusion", "0"); + StatefulSDPFusion(); +}; + +} // namespace intel_cpu +} // namespace ov diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index a034acb2572..ebc3bc61bff 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -112,6 +112,7 @@ #include "transformations/cpu_opset/common/pass/move_eltwise_up_data_movement.hpp" #include "transformations/cpu_opset/common/pass/swap_convert_transpose.hpp" #include "transformations/cpu_opset/common/pass/rope_fusion.hpp" +#include "transformations/cpu_opset/common/pass/stateful_sdp_fusion.hpp" // Snippets #include "snippets/pass/tokenization.hpp" @@ -660,6 +661,7 @@ void Transformations::PostLpt() { CPU_REGISTER_PASS_X64(postLPTPassManager, EliminateStridedSlice); CPU_REGISTER_PASS_X64(postLPTPassManager, RoPEFusion); + CPU_REGISTER_PASS_X64(postLPTPassManager, StatefulSDPFusion); postLPTPassManager.run_passes(model); } diff --git a/src/plugins/intel_cpu/src/utils/plain_tensor.hpp b/src/plugins/intel_cpu/src/utils/plain_tensor.hpp index 5c2202f2ba1..f6441e18a68 100644 --- a/src/plugins/intel_cpu/src/utils/plain_tensor.hpp +++ b/src/plugins/intel_cpu/src/utils/plain_tensor.hpp @@ -169,10 +169,19 @@ struct PlainTensor { } void reset(MemoryPtr mem) { - assert_dt
(mem->getDesc().getPrecision()); + const auto& mem_desc = mem->getDesc(); + assert_dt
(mem_desc.getPrecision()); + const auto* desc_ptr = mem_desc.as(); + // not support block layout + OPENVINO_ASSERT(desc_ptr && desc_ptr->getOrder().size() == mem->getStaticDims().size()); m_mem = mem; + VectorDims strides(desc_ptr->getStrides().size()); + const auto& orders = desc_ptr->getOrder(); + for (size_t i = 0; i < orders.size(); i++) { + strides[orders[i]] = desc_ptr->getStrides()[i]; + } // this reshape_to() can do reshape w/o additional cost - resize(mem->getStaticDims(), reinterpret_cast(mem->getData())); + resize(mem->getStaticDims(), reinterpret_cast(mem->getData()), &strides); } ov::element::Type get_precision(void) { @@ -327,14 +336,14 @@ struct PlainTensor { return new_tensor_view; } - void resize(const VectorDims& new_dims, DT* data = nullptr) { + void resize(const VectorDims& new_dims, DT* data = nullptr, const VectorDims* strides = nullptr) { // initialize strides for compact/dense tensor m_rank = new_dims.size(); assert(m_rank <= PLAINTENSOR_RANK_MAX); size_t stride = 1; for (int i = m_rank - 1; i >= 0; i--) { m_dims[i] = new_dims[i]; - m_strides[i] = stride; + m_strides[i] = strides ? (*strides)[i] : stride; stride *= new_dims[i]; } diff --git a/src/plugins/intel_cpu/tests/functional/single_layer_tests/shapeof.cpp b/src/plugins/intel_cpu/tests/functional/single_layer_tests/shapeof.cpp index 93e5f614894..3daef205663 100644 --- a/src/plugins/intel_cpu/tests/functional/single_layer_tests/shapeof.cpp +++ b/src/plugins/intel_cpu/tests/functional/single_layer_tests/shapeof.cpp @@ -72,7 +72,7 @@ protected: for (auto&& shape : inputDynamicShapes) params.push_back(std::make_shared(inType, shape)); - auto shapeOf = std::make_shared(params[0], ngraph::element::i32); + auto shapeOf = std::make_shared(params.front(), ngraph::element::i32); function = makeNgraphFunction(netPrecision, params, shapeOf, "ShapeOf"); } @@ -85,29 +85,6 @@ TEST_P(ShapeOfLayerCPUTest, CompareWithRefs) { namespace { -/* CPU PARAMS */ -std::vector getCpuInfoForDimsCount(const size_t dimsCount = 3) { - std::vector resCPUParams; - if (dimsCount == 5) { - resCPUParams.push_back(CPUSpecificParams{{nCdhw16c}, {x}, {}, {}}); - resCPUParams.push_back(CPUSpecificParams{{nCdhw8c}, {x}, {}, {}}); - resCPUParams.push_back(CPUSpecificParams{{ncdhw}, {x}, {}, {}}); - resCPUParams.push_back(CPUSpecificParams{{ndhwc}, {x}, {}, {}}); - } else if (dimsCount == 4) { - resCPUParams.push_back(CPUSpecificParams{{nChw16c}, {x}, {}, {}}); - resCPUParams.push_back(CPUSpecificParams{{nChw8c}, {x}, {}, {}}); - resCPUParams.push_back(CPUSpecificParams{{nchw}, {x}, {}, {}}); - resCPUParams.push_back(CPUSpecificParams{{nhwc}, {x}, {}, {}}); - } else { - resCPUParams.push_back(CPUSpecificParams{{nCw16c}, {x}, {}, {}}); - resCPUParams.push_back(CPUSpecificParams{{nCw8c}, {x}, {}, {}}); - resCPUParams.push_back(CPUSpecificParams{{abc}, {x}, {}, {}}); - resCPUParams.push_back(CPUSpecificParams{{acb}, {x}, {}, {}}); - } - - return resCPUParams; -} - const std::vector netPrecisions = { ElementType::f32, ElementType::bf16, @@ -119,17 +96,9 @@ std::vector inShapesDynamic3d = { { {-1, -1, -1}, { - { 8, 5, 4 }, - { 8, 5, 3 }, - { 8, 5, 2 } - } - }, - { - {-1, -1, -1}, - { - { 1, 2, 4 }, - { 1, 2, 3 }, - { 1, 2, 2 } + { 8, 16, 4 }, + { 8, 16, 3 }, + { 8, 16, 2 } } } }; @@ -138,36 +107,20 @@ std::vector inShapesDynamic4d = { { {-1, -1, -1, -1}, { - { 8, 5, 3, 4 }, - { 8, 5, 3, 3 }, - { 8, 5, 3, 2 } + { 8, 16, 3, 4 }, + { 8, 16, 3, 3 }, + { 8, 16, 3, 2 } } }, - { - {-1, -1, -1, -1}, - { - { 1, 2, 3, 4 }, - { 1, 2, 3, 3 }, - { 1, 2, 3, 2 } - } - } }; std::vector inShapesDynamic5d = { { { -1, -1, -1, -1, -1 }, { - { 8, 5, 3, 2, 4 }, - { 8, 5, 3, 2, 3 }, - { 8, 5, 3, 2, 2 } - } - }, - { - {-1, -1, -1, -1, -1}, - { - { 1, 2, 3, 4, 4 }, - { 1, 2, 3, 4, 3 }, - { 1, 2, 3, 4, 2 } + { 8, 16, 3, 2, 4 }, + { 8, 16, 3, 2, 3 }, + { 8, 16, 3, 2, 2 } } } }; @@ -175,19 +128,19 @@ const auto params5dDynamic = ::testing::Combine( ::testing::Combine( ::testing::ValuesIn(inShapesDynamic5d), ::testing::ValuesIn(netPrecisions)), - ::testing::ValuesIn(getCpuInfoForDimsCount(5))); + ::testing::Values(emptyCPUSpec)); const auto params4dDynamic = ::testing::Combine( ::testing::Combine( ::testing::ValuesIn(inShapesDynamic4d), ::testing::ValuesIn(netPrecisions)), - ::testing::ValuesIn(getCpuInfoForDimsCount(4))); + ::testing::Values(emptyCPUSpec)); const auto params3dDynamic = ::testing::Combine( ::testing::Combine( ::testing::ValuesIn(inShapesDynamic3d), ::testing::ValuesIn(netPrecisions)), - ::testing::ValuesIn(getCpuInfoForDimsCount(3))); + ::testing::Values(emptyCPUSpec)); // We don't check static case, because of constant folding INSTANTIATE_TEST_SUITE_P(smoke_ShapeOf3dDynamicLayoutTest, ShapeOfLayerCPUTest, diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_sdp.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_sdp.cpp new file mode 100644 index 00000000000..bf5dac2d822 --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/concat_sdp.cpp @@ -0,0 +1,228 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include +#include + +#include "ov_models/builders.hpp" +#include "ov_models/utils/ov_helpers.hpp" +#include "shared_test_classes/base/layer_test_utils.hpp" +#include "shared_test_classes/base/ov_subgraph.hpp" +#include "test_utils/cpu_test_utils.hpp" +#include "common_test_utils/include/common_test_utils/ov_tensor_utils.hpp" + +using namespace ov::test; +using namespace ngraph; +using namespace CPUTestUtils; +using namespace InferenceEngine; + +namespace SubgraphTestsDefinitions { + +using ConcatSDPTestParams = std::tuple, + bool // has ShapeOf + >; +// Subgraph: +/* Parameter + * | + * Parameter ReadValue | ReadValue Parameter + * \ / | \ / + * \ / | \ / + * Concat | Concat + * / \ | / \ + * / \ | / \ + * / \ | / \ + * Assign ScaledDotProductAttention Assign + * | + * Add + * | + * Result + */ + +class ConcatSDPTest : public testing::WithParamInterface, virtual public ov::test::SubgraphBaseTest, public CPUTestsBase { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj) { + ElementType inType; + std::vector inputShapes; + bool hasShapeof; + std::tie(inType, inputShapes, hasShapeof) = obj.param; + std::ostringstream result; + result << "IS="; + for (const auto& shape : inputShapes) { + result << ov::test::utils::partialShape2str({shape.first}) << "_"; + } + result << "TS="; + for (const auto& shape : inputShapes) { + result << "("; + if (!shape.second.empty()) { + for (const auto& itr : shape.second) { + result << ov::test::utils::vec2str(itr); + } + } + result << ")_"; + } + result << "Prc=" << inType << "_"; + result << "HasShapeOf=" << hasShapeof; + return result.str(); + } + + void SetUp() override { + ElementType inType; + std::vector inputShapes; + bool hasShapeOf; + std::tie(inType, inputShapes, hasShapeOf) = this->GetParam(); + targetDevice = ov::test::utils::DEVICE_CPU; + rel_threshold = 1e-4f; + if (inType == ElementType::bf16) { + configuration.insert({"ENFORCE_BF16", "YES"}); + rel_threshold = 0.01f; + } + init_input_shapes(inputShapes); + ov::ParameterVector inputParams; + // q,k,v + inputParams.push_back(std::make_shared(inType, inputDynamicShapes[0])); + inputParams.push_back(std::make_shared(inType, inputDynamicShapes[0])); + inputParams.push_back(std::make_shared(inType, inputDynamicShapes[0])); + inputParams[0]->set_friendly_name("q"); + inputParams[1]->set_friendly_name("k"); + inputParams[2]->set_friendly_name("v"); + // pastkv init_cost + inputParams.push_back(std::make_shared(inType, inputDynamicShapes[1])); + auto var_k = std::make_shared( + ov::op::util::VariableInfo{inputDynamicShapes[1], inType, "pastk"}); + auto pastk = std::make_shared(inputParams[3], var_k); + pastk->set_friendly_name("pastk_r"); + auto var_v = std::make_shared( + ov::op::util::VariableInfo{inputDynamicShapes[1], inType, "pastv"}); + auto pastv = std::make_shared(inputParams[3], var_v); + pastv->set_friendly_name("pastv_r"); + std::shared_ptr pastk_shapeof, pastv_shapeof; + if (hasShapeOf) { + pastk_shapeof = std::make_shared(pastk); + pastv_shapeof = std::make_shared(pastv); + } + auto concatK = std::make_shared(OutputVector{pastk, inputParams[1]}, 2); + auto concatV = std::make_shared(OutputVector{pastv, inputParams[2]}, 2); + auto sdp = std::make_shared(inputParams[0], concatK, concatV, false); + sdp->set_friendly_name("mha"); + auto add = std::make_shared(sdp, op::v0::Constant::create(inType, {1}, {1.0f})); + auto pastk_assign = std::make_shared(concatK, var_k); + auto pastv_assign = std::make_shared(concatV, var_v); + pastk_assign->set_friendly_name("pastk_w"); + pastv_assign->set_friendly_name("pastv_w"); + + ResultVector results{std::make_shared(add)}; + if (hasShapeOf) { + results.push_back(std::make_shared(pastk_shapeof)); + results.push_back(std::make_shared(pastv_shapeof)); + } + SinkVector sinks{pastk_assign, pastv_assign}; + function = std::make_shared(results, sinks, inputParams, "ConcatSDP"); + targetDevice = ov::test::utils::DEVICE_CPU; + + functionRefs = function->clone(); + pass::Manager manager; + // decompose ScaledDotProductAttention + manager.register_pass(); + manager.run_passes(functionRefs); + } + void generate_inputs(const std::vector& targetInputStaticShapes) override { + std::vector shapes(4); + shapes[0] = targetInputStaticShapes[0]; + shapes[1] = targetInputStaticShapes[0]; + shapes[2] = targetInputStaticShapes[0]; + shapes[3] = targetInputStaticShapes[1]; + SubgraphBaseTest::generate_inputs(shapes); + } + template + void strided_iota(IT first, size_t n, T value, T stride) { + for (size_t i = 0; i < n; i++) { + *first++ = value; + value += stride; + } + } + void generate(int idx, const std::vector& targetInputStaticShapes) { + inputs.clear(); + auto create_input = [this] (std::shared_ptr param, ov::Shape shape, float val) { + if (param->get_element_type() == element::f32) { + ov::Tensor t{ov::element::f32, shape}; + strided_iota(static_cast(t.data()), t.get_size(), val, 0.1f); + inputs.insert({param, t}); + } else { + ov::Tensor t{ov::element::bf16, shape}; + strided_iota(static_cast(t.data()), t.get_size(), val, 0.1f); + inputs.insert({param, t}); + } + }; + // q, k, v + create_input(function->get_parameters()[0], targetInputStaticShapes[0], idx + 1.0f); + create_input(function->get_parameters()[1], targetInputStaticShapes[0], idx + 2.0f); + create_input(function->get_parameters()[2], targetInputStaticShapes[0], idx + 3.0f); + create_input(function->get_parameters()[3], targetInputStaticShapes[1], idx + 4.0f); + } + void prepare() { + compile_model(); + inferRequest = compiledModel.create_infer_request(); + ASSERT_TRUE(inferRequest); + } + void reset() { + for (auto&& state : inferRequest.query_state()) { + state.reset(); + } + inferRequest = ov::InferRequest(); + } + std::vector run_test(std::shared_ptr model) { + function = model; + prepare(); + std::vector outputs; + int idx = 0; + for (auto&& shapes : targetStaticShapes) { + generate(idx++, shapes); + for (const auto& input : inputs) { + inferRequest.set_tensor(input.first, input.second); + } + inferRequest.infer(); + auto outputTensor = inferRequest.get_output_tensor(0); + ov::Tensor copy{outputTensor.get_element_type(), outputTensor.get_shape()}; + outputTensor.copy_to(copy); + outputs.push_back(copy); + } + reset(); + + return outputs; + } +}; + +TEST_P(ConcatSDPTest, CompareWithRefs) { + auto actualOutputs = run_test(function); + CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 1); + CheckNumberOfNodesWithType(compiledModel, "Concatenation", 0); + CheckNumberOfNodesWithType(compiledModel, "Reorder", 0); + auto expectedOutputs = run_test(functionRefs); + CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 0); + for (size_t i = 0; i < actualOutputs.size(); i++) { + ov::test::utils::compare(expectedOutputs[i], actualOutputs[i], abs_threshold, rel_threshold); + } +} + +namespace { +const std::vector> inputShapes = { + // dynamic batch + { + // B, H, L1, S + {{1, 8, -1, 64}, {{1, 8, 10, 64}, {1, 8, 1, 64}, {1, 8, 1, 64}, {1, 8, 20, 64}, {1, 8, 1, 64}}}, + // B, H, L0, S + {{1, 8, -1, 64}, {{1, 8, 0, 64}, {1, 8, 10, 64}, {1, 8, 11, 64}, {1, 8, 12, 64}, {1, 8, 32, 64}}}, + }, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_ConcatSDPTest, + ConcatSDPTest, + ::testing::Combine(::testing::Values(ElementType::f32), + ::testing::ValuesIn(inputShapes), + ::testing::Values(true, false)), + ConcatSDPTest::getTestCaseName); + +} // namespace +} // namespace SubgraphTestsDefinitions diff --git a/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/shapeof_any_layout.cpp b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/shapeof_any_layout.cpp new file mode 100644 index 00000000000..68439615ec5 --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/subgraph_tests/src/shapeof_any_layout.cpp @@ -0,0 +1,201 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "test_utils/cpu_test_utils.hpp" + +#include "ov_models/builders.hpp" +#include "ov_models/utils/ov_helpers.hpp" +#include "shared_test_classes/base/ov_subgraph.hpp" + +using namespace InferenceEngine; +using namespace CPUTestUtils; + +using InputShape = ov::test::InputShape; +using ElementType = ov::element::Type_t; + +namespace SubgraphTestsDefinitions { + +// ┌────────┐ +// │ Param │ +// └───┬────┘ +// │ +// │ +// │ +// ┌───┴────┐ To simulate different layouts +// │ Eltwise│ ◄───────────────────────────── +// └───┬────┘ +// │ No Reorders are expected +// │ ◄─────────────────────────── +// │ +// ┌───┴────┐ +// │ShapeOf │ +// └───┬────┘ +// │ +// │ +// │ +// ┌───┴────┐ +// │ Output │ +// └────────┘ + +typedef std::tuple< + InputShape, + ElementType // Net precision +> ShapeOfAnyLayoutParams; + +typedef std::tuple< + ShapeOfAnyLayoutParams, + CPUSpecificParams +> ShapeOfAnyLayoutCPUTestParamsSet; + +class ShapeOfAnyLayoutCPUTest : public testing::WithParamInterface, + virtual public ov::test::SubgraphBaseTest, public CPUTestsBase { +public: + static std::string getTestCaseName(testing::TestParamInfo obj) { + SubgraphTestsDefinitions::ShapeOfAnyLayoutParams basicParamsSet; + CPUSpecificParams cpuParams; + std::tie(basicParamsSet, cpuParams) = obj.param; + ElementType netPr; + InputShape inputShape; + + std::tie(inputShape, netPr) = basicParamsSet; + std::ostringstream result; + result << "ShapeOfTest_"; + result << std::to_string(obj.index) << "_"; + result << "Prec=" << netPr << "_"; + result << CPUTestsBase::getTestCaseName(cpuParams) << "_"; + result << "IS="; + result << ov::test::utils::partialShape2str({inputShape.first}) << "_"; + result << "TS=("; + for (const auto& shape : inputShape.second) { + result << ov::test::utils::vec2str(shape) << "_"; + } + result << ")"; + return result.str(); + } +protected: + void SetUp() override { + targetDevice = ov::test::utils::DEVICE_CPU; + + ShapeOfAnyLayoutParams basicParamsSet; + CPUSpecificParams cpuParams; + std::tie(basicParamsSet, cpuParams) = this->GetParam(); + std::vector eltwiseInFmts, eltwiseOutFmts; + std::tie(eltwiseInFmts, eltwiseOutFmts, priority, selectedType) = cpuParams; + + auto netPrecision = ElementType::undefined; + InputShape inputShape; + std::tie(inputShape, netPrecision) = basicParamsSet; + init_input_shapes({inputShape}); + + inType = ov::element::Type(netPrecision); + outType = ElementType::i32; + selectedType = makeSelectedTypeStr("ref", inType); + + ov::ParameterVector params; + for (auto&& shape : inputDynamicShapes) + params.push_back(std::make_shared(inType, shape)); + + //make a stub eltwise node to enforce layout, since ShapeOf just mimic any input layout + auto eltwise = ngraph::builder::makeActivation(params[0], inType, ov::test::utils::ActivationTypes::Relu); + eltwise->get_rt_info() = makeCPUInfo(eltwiseInFmts, eltwiseOutFmts, {}); + + auto shapeOf = std::make_shared(eltwise, ngraph::element::i32); + + function = makeNgraphFunction(netPrecision, params, shapeOf, "ShapeOf"); + } +}; + +TEST_P(ShapeOfAnyLayoutCPUTest, CompareWithRefs) { + run(); + CheckPluginRelatedResults(compiledModel, "ShapeOf"); + CheckNumberOfNodesWithType(compiledModel, "Reorder", 1); +} + +namespace { + +/* CPU PARAMS */ +std::vector getCpuInfoForDimsCount(const size_t dimsCount = 3) { + std::vector resCPUParams; + const bool avx512_target = with_cpu_x86_avx512f(); + + if (dimsCount == 5) { + auto blocked_format = avx512_target ? nCdhw16c : nCdhw8c; + resCPUParams.push_back(CPUSpecificParams{{blocked_format}, {blocked_format}, {}, {}}); + resCPUParams.push_back(CPUSpecificParams{{ndhwc}, {ndhwc}, {}, {}}); + } else if (dimsCount == 4) { + auto blocked_format = avx512_target ? nChw16c : nChw8c; + resCPUParams.push_back(CPUSpecificParams{{blocked_format}, {blocked_format}, {}, {}}); + resCPUParams.push_back(CPUSpecificParams{{nhwc}, {nhwc}, {}, {}}); + } else { + auto blocked_format = avx512_target ? nCw16c : nCw8c; + resCPUParams.push_back(CPUSpecificParams{{blocked_format}, {blocked_format}, {}, {}}); + resCPUParams.push_back(CPUSpecificParams{{acb}, {acb}, {}, {}}); + } + + return resCPUParams; +} + +const std::vector netPrecisions = { + ElementType::f32 +}; + +std::vector inShapesDynamic3d = { + { + {-1, 16, -1}, + { + { 8, 16, 4 }, + { 8, 16, 3 }, + { 8, 16, 2 } + } + } +}; + +std::vector inShapesDynamic4d = { + { + {-1, 16, -1, -1}, + { + { 8, 16, 3, 4 }, + { 8, 16, 3, 3 }, + { 8, 16, 3, 2 } + } + }, +}; + +std::vector inShapesDynamic5d = { + { + { -1, 16, -1, -1, -1 }, + { + { 8, 16, 3, 2, 4 }, + { 8, 16, 3, 2, 3 }, + { 8, 16, 3, 2, 2 } + } + } +}; +const auto params5dDynamic = ::testing::Combine( + ::testing::Combine( + ::testing::ValuesIn(inShapesDynamic5d), + ::testing::ValuesIn(netPrecisions)), + ::testing::ValuesIn(getCpuInfoForDimsCount(5))); + +const auto params4dDynamic = ::testing::Combine( + ::testing::Combine( + ::testing::ValuesIn(inShapesDynamic4d), + ::testing::ValuesIn(netPrecisions)), + ::testing::ValuesIn(getCpuInfoForDimsCount(4))); + +const auto params3dDynamic = ::testing::Combine( + ::testing::Combine( + ::testing::ValuesIn(inShapesDynamic3d), + ::testing::ValuesIn(netPrecisions)), + ::testing::ValuesIn(getCpuInfoForDimsCount(3))); + +// We don't check static case, because of constant folding +INSTANTIATE_TEST_SUITE_P(smoke_ShapeOf3dAnyLayoutTest, ShapeOfAnyLayoutCPUTest, + params3dDynamic, ShapeOfAnyLayoutCPUTest::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_ShapeOf4dAnyLayoutTest, ShapeOfAnyLayoutCPUTest, + params4dDynamic, ShapeOfAnyLayoutCPUTest::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_ShapeOf5dAnyLayoutTest, ShapeOfAnyLayoutCPUTest, + params5dDynamic, ShapeOfAnyLayoutCPUTest::getTestCaseName); +} // namespace +} // namespace SubgraphTestsDefinitions diff --git a/src/plugins/intel_cpu/tests/unit/graph/scaled_attn.cpp b/src/plugins/intel_cpu/tests/unit/graph/scaled_attn.cpp new file mode 100644 index 00000000000..081d9ddfc7b --- /dev/null +++ b/src/plugins/intel_cpu/tests/unit/graph/scaled_attn.cpp @@ -0,0 +1,161 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "common_test_utils/common_utils.hpp" +#include "cache/multi_cache.h" +#include "ov_models/builders.hpp" +#include "nodes/scaled_attn.h" +#include "nodes/input.h" +#include "graph.h" +#include "cpu_tensor.h" + +using namespace ov::intel_cpu; + +TEST(ScaledAttnGraphTest, smoke_Check_Scaled_Concat_Noplace) { + auto build_graph = [](const ov::Shape& shape, float* qkv_val, float* past_kv_val) { + auto qkv = ov::op::v0::Constant::create(ov::element::f32, shape, qkv_val); + qkv->set_friendly_name("qkv_const"); + auto pastkv = ov::op::v0::Constant::create(ov::element::f32, shape, past_kv_val); + pastkv->set_friendly_name("pastkv_const"); + // only need a dynamic parameter but its value will not be used + auto attn = std::make_shared(ov::element::f32, ov::PartialShape{-1}); + attn->set_friendly_name("attn"); + + ov::intel_cpu::ScaledDotProductAttentionStub::Config config; + config.fuse_concat = true; + config.is_causal = true; + auto sdpa = std::make_shared(ov::OutputVector{qkv, qkv, qkv, attn, pastkv, pastkv}, config); + auto out_qkv = std::make_shared(sdpa->output(0)); + out_qkv->set_friendly_name("qkv"); + auto out_pastk = std::make_shared(sdpa->output(1)); + out_pastk->set_friendly_name("pastk"); + auto out_pastv = std::make_shared(sdpa->output(2)); + out_pastv->set_friendly_name("pastv"); + + std::unordered_set nodes_set; + std::vector graph_edges; + + auto add_edge = [&](const NodePtr& parent, const NodePtr& child, size_t parentPort, size_t childPort) -> void { + auto edge = std::make_shared(parent, child, parentPort, childPort); + child->addEdge(edge); + graph_edges.push_back(edge); + nodes_set.insert(parent); + nodes_set.insert(child); + }; + + //create graph context + Config conf; + conf.rtCacheCapacity = 0; + auto context = std::make_shared(conf, nullptr, nullptr, false); + + auto qkv_node = std::make_shared(qkv, context); + auto pastkv_node = std::make_shared(pastkv, context); + auto attn_node = std::make_shared(attn, context); + auto sdpa_node = std::make_shared(sdpa, context); + auto out_qkv_node = std::make_shared(out_qkv, context); + auto out_pastk_node = std::make_shared(out_pastk, context); + auto out_pastv_node = std::make_shared(out_pastv, context); + + add_edge(qkv_node, sdpa_node, 0, 0); + add_edge(qkv_node, sdpa_node, 0, 1); + add_edge(qkv_node, sdpa_node, 0, 2); + add_edge(attn_node, sdpa_node, 0, 3); + add_edge(pastkv_node, sdpa_node, 0, 4); + add_edge(pastkv_node, sdpa_node, 0, 5); + add_edge(sdpa_node, out_qkv_node, 0, 0); + add_edge(sdpa_node, out_pastk_node, 1, 0); + add_edge(sdpa_node, out_pastv_node, 2, 0); + + std::vector graph_nodes(nodes_set.begin(), nodes_set.end()); + + Graph graph; + graph.CreateGraph(graph_nodes, graph_edges, context, "test_graph"); + return graph; + }; + + auto run_graph = [] (Graph& graph) { + graph.GetInputNodesMap().begin()->second->redefineOutputMemory(0, {1}); + + for (auto& node : graph.GetNodes()) { + if (node->isDynamicNode()) { + node->updateShapes(); + node->updateDynamicParams(); + } + } + graph.Infer(); + }; + + auto check_graph = [] (Graph& graph, std::map>& expected) { + auto& outputNodesMap = graph.GetOutputNodesMap(); + auto is_same = [] (float a, float b) { + return std::abs(a - b) < 0.0001f; + }; + for (auto &outputMap : outputNodesMap) { + auto name = outputMap.first; + if (expected.count(name) == 0) { + continue; + } + auto node = outputMap.second; + auto parentEdge = node->getParentEdgeAt(0); + const auto& memory = parentEdge->getMemoryPtr(); + auto size = memory->getSize() / sizeof(float); + auto p = reinterpret_cast(memory->getData()); + for (size_t i = 0; i < size; i++) { + ASSERT_EQ(is_same(p[i], expected.at(name).first[i]), true); + } + ASSERT_EQ(memory->getShape(), ov::intel_cpu::Shape(expected.at(name).second)); + } + }; + + auto find_node_type = [](const Graph& graph, Type type) -> NodePtr { + auto&& nodes = graph.GetNodes(); + auto itr = + std::find_if(nodes.begin(), nodes.end(), [=](const NodePtr& node){ return type == node->getType(); }); + + if (itr == nodes.end()) { + return nullptr; + } + + return (*itr); + }; + + auto strided_iota = [] (float* first, size_t n, float value, float stride) { + for (size_t i = 0; i < n; i++) { + *first++ = value; + value += stride; + } + }; + + ov::Shape shape{1, 1, 8, 8}; + const size_t elements_count = std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies()); + std::vector val(elements_count * 2); + strided_iota(val.data(), val.size(), -10.0f, 0.1f); + auto graph = build_graph(shape, val.data() + elements_count, val.data()); + run_graph(graph); + // if no inplace, the pastk and pastv will concat, check shape and value + ov::Shape expectedShape(shape); + expectedShape[2] *= 2; + std::map> expected{ + {"pastk", std::make_pair(val.data(), expectedShape)}, + {"pastv", std::make_pair(val.data(), expectedShape)}}; + check_graph(graph, expected); + auto spd = find_node_type(graph, Type::ScaledDotProductAttention)->getSelectedPrimitiveDescriptor(); + ASSERT_EQ(spd->getConfig().outConfs[1].inPlace(), -1); + ASSERT_EQ(spd->getConfig().outConfs[2].inPlace(), -1); +} \ No newline at end of file diff --git a/src/plugins/intel_cpu/tests/unit/transformations/state_concat_sdpa.cpp b/src/plugins/intel_cpu/tests/unit/transformations/state_concat_sdpa.cpp new file mode 100644 index 00000000000..6f30e3390f4 --- /dev/null +++ b/src/plugins/intel_cpu/tests/unit/transformations/state_concat_sdpa.cpp @@ -0,0 +1,105 @@ +// Copyright (C) 2018-2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/ov_test_utils.hpp" + +using namespace testing; +using namespace ov::intel_cpu; +using namespace ov; + +static std::shared_ptr makeSDPA(const ov::PartialShape& inputShape, bool isRef = false, bool hasConvert = false) { + auto q = std::make_shared(element::f32, inputShape); + auto k = std::make_shared(element::f32, inputShape); + auto v = std::make_shared(element::f32, inputShape); + auto init = std::make_shared(element::f32, inputShape); + auto var_k = std::make_shared( + ov::op::util::VariableInfo{inputShape, element::f32, "pastk"}); + std::shared_ptr pastk = std::make_shared(k, var_k); + auto var_v = std::make_shared( + ov::op::util::VariableInfo{inputShape, element::f32, "pastv"}); + std::shared_ptr pastv = std::make_shared(v, var_v); + Output concatK, concatV, sdp; + if (hasConvert) { + pastk = std::make_shared(pastk, element::f32); + pastv = std::make_shared(pastv, element::f32); + } + if (isRef) { + ov::intel_cpu::ScaledDotProductAttentionStub::Config config; + config.fuse_concat = true; + auto new_node = std::make_shared(OutputVector{q, k, v, pastk, pastv}, config); + sdp = new_node->output(0); + concatK = new_node->output(1); + concatV = new_node->output(2); + } else { + concatK = std::make_shared(OutputVector{pastk, k}, 2); + concatV = std::make_shared(OutputVector{pastv, v}, 2); + sdp = std::make_shared(q, concatK, concatV, false); + } + if (hasConvert) { + concatK = std::make_shared(concatK, element::f32); + concatV = std::make_shared(concatV, element::f32); + } + auto pastk_assign = std::make_shared(concatK, var_k); + auto pastv_assign = std::make_shared(concatV, var_v); + auto add = std::make_shared(sdp, op::v0::Constant::create(element::f32, {1}, {1.0f})); + + ResultVector results{std::make_shared(add)}; + SinkVector sinks{pastk_assign, pastv_assign}; + return std::make_shared(results, sinks, ParameterVector{q, k, v, init}, "ConcatSDP"); +} + +TEST(TransformationTests, StateConcatSDPA) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + using namespace ov; + auto inputShape = ov::PartialShape{-1, 8, -1, 64}; + { + f = makeSDPA(inputShape); + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + } + //construct ref interaction + { + f_ref = makeSDPA(inputShape, true); + } + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; + } +} + +TEST(TransformationTests, StateConcatSDPAWithConvert) { + std::shared_ptr f(nullptr), f_ref(nullptr); + { + using namespace ov; + auto inputShape = ov::PartialShape{-1, 8, -1, 64}; + { + f = makeSDPA(inputShape, false, true); + pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(f); + } + //construct ref interaction + { + f_ref = makeSDPA(inputShape, true, true); + } + auto res = compare_functions(f, f_ref); + ASSERT_TRUE(res.first) << res.second; + } +}