[CPU] Add optimized memory management for SDPA KV cache (#21242)

This commit is contained in:
Maksim Kutakov 2023-11-30 11:07:43 +01:00 committed by GitHub
parent 718b5a60bf
commit 405d97e4a5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
32 changed files with 1952 additions and 364 deletions

View File

@ -245,6 +245,49 @@ void MemoryMngrWithReuse::destroy(void *ptr) {
dnnl::impl::free(ptr);
}
// Returns the raw pointer to the currently held buffer
// (nullptr when nothing has been allocated or attached yet).
void* MemoryMngrRealloc::getRawPtr() const noexcept {
return m_data.get();
}
// Attaches an externally owned buffer of the given size. Ownership stays with
// the caller: the no-op release() deleter is installed instead of destroy().
void MemoryMngrRealloc::setExtBuff(void *ptr, size_t size) {
m_useExternalStorage = true;
m_memUpperBound = size;
m_data = decltype(m_data)(ptr, release);
}
// Grows the buffer when the requested size exceeds the current capacity.
// Allocation is cache-line aligned and grows geometrically (x2) to amortize
// reallocation cost; previously stored content is preserved via memcpy.
// Returns true when a reallocation actually happened.
bool MemoryMngrRealloc::resize(size_t size) {
    constexpr int cacheLineSize = 64;
    constexpr size_t growFactor = 2;
    bool sizeChanged = false;
    if (size > m_memUpperBound) {
        // Guard against size_t overflow: if doubling wraps around, the
        // allocation would be smaller than requested and the memcpy below
        // would overflow the new buffer. Fall back to the requested size.
        const size_t grown = size * growFactor;
        if (grown / growFactor == size) {
            size = grown;
        }
        void *ptr = dnnl::impl::malloc(size, cacheLineSize);
        if (!ptr) {
            OPENVINO_THROW("Failed to allocate ", size, " bytes of memory");
        }
        // Preserve the previously stored content (e.g. KV cache history).
        if (auto src = m_data.get()) {
            std::memcpy(ptr, src, m_memUpperBound);
        }
        m_memUpperBound = size;
        m_useExternalStorage = false;
        // Assigning the unique_ptr frees the previous buffer through its own
        // deleter (no-op release() when it was external storage).
        m_data = decltype(m_data)(ptr, destroy);
        sizeChanged = true;
    }
    return sizeChanged;
}
// True when the held buffer is owned by an external entity (set via setExtBuff).
bool MemoryMngrRealloc::hasExtBuffer() const noexcept {
return m_useExternalStorage;
}
// Intentionally a no-op: deleter used for externally owned buffers.
void MemoryMngrRealloc::release(void *ptr) {}
// Deleter for buffers allocated by this manager via dnnl::impl::malloc.
void MemoryMngrRealloc::destroy(void *ptr) {
dnnl::impl::free(ptr);
}
// Forwards to the wrapped memory manager.
void* DnnlMemoryMngr::getRawPtr() const noexcept {
return m_pMemMngr->getRawPtr();
}

View File

@ -89,6 +89,23 @@ private:
static void destroy(void *ptr);
};
// Memory manager that grows its allocation geometrically on resize and copies
// the previously stored content into the new buffer, so the stored data
// survives reallocation (used to back growing stateful tensors such as the
// SDPA KV cache).
class MemoryMngrRealloc : public IMemoryMngr {
public:
MemoryMngrRealloc() : m_data(nullptr, release) {}
void* getRawPtr() const noexcept override;
void setExtBuff(void* ptr, size_t size) override;
bool resize(size_t size) override;
bool hasExtBuffer() const noexcept override;
private:
// True when m_data points to storage owned by an external entity.
bool m_useExternalStorage = false;
// Current buffer capacity in bytes.
size_t m_memUpperBound = 0ul;
// Buffer with a state-dependent deleter: no-op release() for external
// storage, destroy() for memory allocated by this manager.
std::unique_ptr<void, void (*)(void *)> m_data;
static void release(void *ptr);
static void destroy(void *ptr);
};
class IMemoryMngrObserver : public IMemoryMngr {
public:
virtual void registerMemory(Memory* memPtr) = 0;

View File

@ -214,6 +214,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
{ "Unique", Type::Unique},
{ "Ngram", Type::Ngram},
{ "ScaledDotProductAttention", Type::ScaledDotProductAttention},
{ "ScaledDotProductAttentionStub", Type::ScaledDotProductAttention},
{ "RoPE", Type::RoPE},
};
return type_to_name_tbl;

View File

@ -6,6 +6,7 @@
#include "transformations/cpu_opset/common/op/fully_connected.hpp"
#include "transformations/cpu_opset/common/op/leaky_relu.hpp"
#include "transformations/cpu_opset/common/op/power_static.hpp"
#include "transformations/cpu_opset/common/op/sdp.hpp"
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
#include "transformations/cpu_opset/common/op/ngram.hpp"
#include "transformations/cpu_opset/x64/op/mha.hpp"
@ -59,6 +60,7 @@ std::map<std::string, ngraph::OpSet> Extension::getOpSets() {
NGRAPH_OP(NgramNode, ov::intel_cpu)
NGRAPH_OP_X64(MHANode, ov::intel_cpu)
NGRAPH_OP_X64(InteractionNode, ov::intel_cpu)
NGRAPH_OP_X64(ScaledDotProductAttentionStub, ov::intel_cpu)
#undef NGRAPH_OP
return opset;

View File

@ -1441,7 +1441,7 @@ void Graph::GetPerfData(std::vector<ov::ProfilingInfo>& perfMap) const {
}
}
void Graph::RemoveEdge(EdgePtr& edge) {
void Graph::RemoveEdge(const EdgePtr& edge) {
for (auto it = graphEdges.begin(); it != graphEdges.end(); it++) {
if ((*it) == edge) {
edge->drop();
@ -1881,9 +1881,9 @@ void Graph::resolveInPlaceDirection(const NodePtr& node) const {
void Graph::SearchInternalStateNodes() {
for (auto&& node : graphNodes) {
if (node->getType() == Type::MemoryInput) {
auto cur_node = std::dynamic_pointer_cast<node::MemoryInput>(node);
auto cur_node = std::dynamic_pointer_cast<node::MemoryStateNode>(node);
if (!cur_node) {
OPENVINO_THROW("Cannot cast ", node->getName(), " to MemoryInput");
OPENVINO_THROW("Cannot cast ", node->getName(), " to MemoryStateNode");
}
internalStateNodes.insert({cur_node->getId(), cur_node});
}

View File

@ -28,7 +28,7 @@ namespace intel_cpu {
class SyncInferRequest;
namespace node {
class MemoryNode;
class MemoryStateNode;
} // namespace node
class Graph {
@ -123,7 +123,7 @@ public:
void RemoveDroppedNodes();
void RemoveDroppedEdges();
void RemoveEdge(EdgePtr& edge);
void RemoveEdge(const EdgePtr& edge);
void DropNode(const NodePtr& node);
void DropDWConvNode(const NodePtr& node);
@ -197,7 +197,7 @@ public:
}
Status getStatus() const {return status;}
const std::unordered_map<std::string, std::shared_ptr<node::MemoryNode>>&
const std::unordered_map<std::string, std::shared_ptr<node::MemoryStateNode>>&
getInternalStateNodes() const {
return internalStateNodes;
}
@ -259,7 +259,7 @@ private:
std::map<std::string, NodePtr> outputNodesMap;
std::unordered_map<std::string, ProxyMemoryMngrPtr> outputNodesMemMngrMap;
std::unordered_map<std::string, std::shared_ptr<node::MemoryNode>> internalStateNodes;
std::unordered_map<std::string, std::shared_ptr<node::MemoryStateNode>> internalStateNodes;
// these node pointers (from graphNodes) are to avoid regular checking for
// constantness of nodes in Infer methods and calls of

View File

@ -22,6 +22,7 @@
#include "nodes/reduce.h"
#include "nodes/input.h"
#include "nodes/rnn.h"
#include "nodes/memory.hpp"
#include "nodes/common/cpu_convert.h"
#include "onednn/dnnl.h"
@ -182,6 +183,18 @@ void GraphOptimizer::ApplyCommonGraphOptimizations(Graph &graph) {
RemoveSameConvert(graph);
graph.RemoveDroppedNodes();
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "RemoveMemoryInputConvert");
RemoveMemoryInputConvert(graph);
graph.RemoveDroppedNodes();
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "RemoveConvertMemoryOutput");
RemoveConvertMemoryOutput(graph);
graph.RemoveDroppedNodes();
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "MatchSdpaKvCache");
MatchSdpaKvCache(graph);
graph.RemoveDroppedNodes();
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "RemoveDroppedEdges");
graph.RemoveDroppedEdges();
}
@ -2710,5 +2723,153 @@ void GraphOptimizer::RemoveSameConvert(Graph& graph) {
}
}
// Drops Convert nodes that directly follow a MemoryInput node.
void GraphOptimizer::RemoveMemoryInputConvert(Graph &graph) {
    auto isConvertAfterMemInput = [](const NodePtr& candidate) {
        return Type::Convert == candidate->getType() &&
               Type::MemoryInput == candidate->getParentEdgeAt(0)->getParent()->getType();
    };

    auto& nodes = graph.GetNodes();
    for (size_t idx = 0; idx < nodes.size(); idx++) {
        auto candidate = nodes[idx];
        if (isConvertAfterMemInput(candidate)) {
            graph.DropNode(candidate);
        }
    }
}
// Drops Convert nodes whose every consumer is a MemoryOutput node.
void GraphOptimizer::RemoveConvertMemoryOutput(Graph &graph) {
    auto isConvertBeforeMemOutput = [](const NodePtr& candidate) {
        if (Type::Convert != candidate->getType()) {
            return false;
        }
        // Every consumer of the Convert output must be a MemoryOutput.
        for (auto&& consumerEdge : candidate->getChildEdgesAtPort(0)) {
            if (Type::MemoryOutput != consumerEdge->getChild()->getType()) {
                return false;
            }
        }
        return true;
    };

    auto& nodes = graph.GetNodes();
    for (size_t idx = 0; idx < nodes.size(); idx++) {
        auto candidate = nodes[idx];
        if (isConvertBeforeMemOutput(candidate)) {
            graph.DropNode(candidate);
        }
    }
}
// Replaces a MemoryInput node that feeds a single ScaledDotProductAttention
// (KV cache pattern) with a specialized MemoryInputSDPA node, rewiring all
// parent/child edges and re-linking the sibling MemoryOutput.
void GraphOptimizer::MatchSdpaKvCache(Graph &graph) {
auto& graphNodes = graph.GetNodes();
auto isSuitableMemInput = [](const NodePtr& node) -> bool {
if (Type::MemoryInput != node->getType()) {
return false;
}
NodePtr childSdpa = nullptr;
auto&& childEdges = node->getChildEdgesAtPort(0);
// All consumers must be either the (single) SDPA node or ShapeOf.
for (auto&& item : childEdges) {
auto childNode = item->getChild();
if (!one_of(childNode->getType(), Type::ScaledDotProductAttention, Type::ShapeOf)) {
return false;
}
if (Type::ScaledDotProductAttention == childNode->getType()) {
if (childSdpa && childSdpa != childNode) {
//only one child SDPA supported
return false;
}
childSdpa = childNode;
}
}
CPU_GRAPH_OPTIMIZER_SCOPE(MatchSdpaKvCache_isSuitableMemInput);
auto memInputNode = std::dynamic_pointer_cast<node::MemoryInputBase>(node);
OPENVINO_ASSERT(memInputNode, "MemoryInput node ", node->getName(), " has unexpected dynamic type");
// The paired MemoryOutput must be fed by the same SDPA node.
// NOTE(review): when all consumers are ShapeOf, childSdpa stays nullptr and
// this comparison rejects the node — presumably intended; confirm.
auto& memOutputNode = memInputNode->getOutputNode();
auto memOutputParent = memOutputNode.getParentEdgeAt(0)->getParent();
if (memOutputParent != childSdpa) {
return false;
}
return true;
};
for (size_t i = 0; i < graphNodes.size(); i++) {
auto node = graphNodes[i];
if (!isSuitableMemInput(node)) {
continue;
}
CPU_GRAPH_OPTIMIZER_SCOPE(MatchSdpaKvCache_Node);
// Node is already modified
if (auto sdpaMemInput = std::dynamic_pointer_cast<node::MemoryInputSDPA>(node)) {
continue;
}
auto memInputNode = std::dynamic_pointer_cast<node::MemoryInputBase>(node);
OPENVINO_ASSERT(memInputNode, "MemoryInput node ", node->getName(), " has unexpected dynamic type");
// Optional input (initial state value) shape/precision, present only when
// the MemoryInput has a parent edge.
ov::optional<Shape> input_shape;
ov::optional<ov::element::Type> input_prc;
if (!node->getParentEdges().empty()) {
input_shape = ov::optional<Shape>(node->getInputShapeAtPort(0));
input_prc = ov::optional<ov::element::Type>(node->getOriginalInputPrecisionAtPort(0));
}
// Build the replacement node with identical id/name/type and I/O metadata.
auto memInputSdpa = std::make_shared<MemoryInputSDPA>(
memInputNode->getId(),
memInputNode->getName(),
memInputNode->getTypeStr(),
memInputNode->getOutputShapeAtPort(0),
memInputNode->getOriginalOutputPrecisionAtPort(0),
graph.getGraphContext(),
input_shape,
input_prc);
// Re-attach the parent edge (if any) to the new node.
if (!memInputNode->getParentEdges().empty()) {
auto parentEdge = memInputNode->getParentEdgeAt(0);
auto newEdge = std::make_shared<Edge>(parentEdge->getParent(), memInputSdpa, parentEdge->getInputNum(), 0);
memInputSdpa->addEdge(newEdge);
graph.GetEdges().push_back(newEdge);
graph.RemoveEdge(parentEdge);
}
// Re-attach every child edge to the new node.
for (auto&& edge : memInputNode->getChildEdgesAtPort(0)) {
auto newEdge = std::make_shared<Edge>(memInputSdpa, edge->getChild(), 0, edge->getOutputNum());
memInputSdpa->addEdge(newEdge);
graph.GetEdges().push_back(newEdge);
graph.RemoveEdge(edge);
}
//link with memory output
auto& memOutput = memInputNode->getOutputNode();
memInputSdpa->registerOutputNode(&memOutput);
graph.GetNodes().push_back(memInputSdpa);
graph.DropNode(memInputNode);
}
}
} // namespace intel_cpu
} // namespace ov

View File

@ -49,6 +49,9 @@ private:
void MergeTransposeAndReorder(Graph &graph);
void reshapeRnnSeq(Graph &graph);
void RemoveSameConvert(Graph &graph);
void RemoveMemoryInputConvert(Graph &graph);
void RemoveConvertMemoryOutput(Graph &graph);
void MatchSdpaKvCache(Graph &graph);
};
} // namespace intel_cpu

View File

@ -61,7 +61,7 @@ void SyncInferRequest::create_infer_request() {
init_tensor(it.first);
}
//create states according to the list of the MemoryNodes
//create states according to the list of the MemoryStateNodes
for (auto&& node : m_graph->getInternalStateNodes()) {
m_memory_states.emplace_back(node.second->makeState());
}

View File

@ -14,13 +14,81 @@ using namespace InferenceEngine;
namespace ov {
namespace intel_cpu {
VariableStateDoubleBuffer::VariableStateDoubleBuffer(std::string name,
const MemBuilder& mem_build,
MemoryDescPtr external_desc,
MemoryCPtr init_val) :
IVariableState{name}, m_external_desc{external_desc} {
reset_prime_mem(mem_build());
reset_second_mem(mem_build());
// Base variable state: stores the state name and the memory descriptor of the
// external (user-facing) tensor representation.
VariableStateBase::VariableStateBase(const std::string& name, const MemoryDescPtr& external_desc) :
IVariableState{name} , m_external_desc{external_desc} {}
// Produces a fully defined descriptor from a possibly dynamic one by
// substituting zero for every undefined dimension (an empty, but defined
// shape). Already defined descriptors are returned unchanged.
MemoryDescPtr VariableStateBase::to_static(const MemoryDescPtr& desc) {
    if (desc->isDefined()) {
        return desc;
    }
    auto&& dyn_dims = desc->getShape().getDims();
    VectorDims static_dims;
    static_dims.reserve(dyn_dims.size());
    for (auto dim : dyn_dims) {
        static_dims.push_back(dim == Shape::UNDEFINED_DIM ? 0 : dim);
    }
    return desc->cloneWithNewDims(static_dims, true);
}
// Lazily constructed process-wide CPU engine shared by all variable states.
const dnnl::engine& VariableStateBase::get_engine() {
static const dnnl::engine eng(dnnl::engine::kind::cpu, 0);
return eng;
}
// Copies the user-provided tensor into the state's input memory, redefining
// the internal descriptor first when the shapes differ.
void VariableStateBase::set_state(const ov::SoPtr<ov::ITensor>& state) {
m_state = state; // simply to extend the lifetime
auto state_desc = MemoryDescUtils::generateCpuBlockedMemoryDesc(m_state);
const auto& shape = state_desc->getShape();
if (input_mem()->getShape() != shape) {
auto new_desc = internal_desc()->cloneWithNewDims(shape.getStaticDims());
input_mem()->redefineDesc(new_desc);
}
// Wrap the user buffer (no copy) and load it into the internal memory;
// load() handles any precision/layout conversion between descriptors.
auto src = m_state->data();
Memory mem(get_engine(), state_desc, src);
input_mem()->load(mem);
}
// Returns the state as a tensor in the external representation. Fast path:
// when internal and external descriptors are compatible, the internal memory
// is exposed directly without a copy; otherwise the data is converted
// (precision) or reordered (layout) into a freshly allocated buffer.
ov::SoPtr<ov::ITensor> VariableStateBase::get_state() const {
const auto& current_dims = internal_state_mem()->getStaticDims();
auto current_ext_desc = m_external_desc->cloneWithNewDims(current_dims);
auto current_internal_desc = internal_state_mem()->getDescPtr();
if (current_ext_desc->isCompatible(*current_internal_desc)) {
return std::make_shared<Tensor>(internal_state_mem());
}
//test precision
{
auto internal_prc = current_internal_desc->getPrecision();
auto tmp_desc = current_ext_desc->cloneWithNewPrecision(internal_prc);
// If only the precision differs, a plain element-wise convert suffices.
if (tmp_desc->isCompatible(*current_internal_desc)) {
auto mem = std::make_shared<Memory>(get_engine(), current_ext_desc);
size_t elements_to_convert = internal_state_mem()->getDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
auto external_prc = current_ext_desc->getPrecision();
cpu_convert(internal_state_mem()->getData(), mem->getData(), internal_prc, external_prc, elements_to_convert);
return std::make_shared<Tensor>(mem);
}
}
//reorder
auto mem = std::make_shared<Memory>(get_engine(), current_ext_desc);
mem->load(*(internal_state_mem()));
return std::make_shared<Tensor>(mem);
}
VariableStateDoubleBuffer::VariableStateDoubleBuffer(const std::string& name,
const MemoryPtr& first_buffer,
const MemoryPtr& second_buffer,
const MemoryDescPtr& external_desc,
const MemoryCPtr& init_val) :
VariableStateBase(name, external_desc) {
OPENVINO_ASSERT(first_buffer && second_buffer);
reset_prime_mem(first_buffer);
reset_second_mem(second_buffer);
m_internal_desc = prime_mem()->getDescPtr();
auto&& shape = m_internal_desc->getShape();
//TODO what if by some reason we already have internal static state while the node is dynamic, is it even possible?
@ -38,58 +106,6 @@ VariableStateDoubleBuffer::VariableStateDoubleBuffer(std::string name,
}
}
void VariableStateDoubleBuffer::set_state(const ov::SoPtr<ov::ITensor>& state) {
m_state = state; // simply to extend the lifetime
auto state_desc = MemoryDescUtils::generateCpuBlockedMemoryDesc(m_state);
const auto& shape = state_desc->getShape();
if (prime_mem()->getShape() != shape) {
auto new_desc = m_internal_desc->cloneWithNewDims(shape.getStaticDims());
prime_mem()->redefineDesc(new_desc);
}
auto src = m_state->data();
Memory mem(get_engine(), state_desc, src);
prime_mem()->load(mem);
}
const dnnl::engine& VariableStateDoubleBuffer::get_engine() const {
static const dnnl::engine eng(dnnl::engine::kind::cpu, 0);
return eng;
}
ov::SoPtr<ov::ITensor> VariableStateDoubleBuffer::get_state() const {
//TODO , in general case must be synchronized
const auto& current_dims = prime_mem()->getStaticDims();
auto current_ext_desc = m_external_desc->cloneWithNewDims(current_dims);
auto current_internal_desc = prime_mem()->getDescPtr();
if (current_ext_desc->isCompatible(*current_internal_desc)) {
return std::make_shared<Tensor>(prime_mem());
}
//test precision
{
auto internal_prc = current_internal_desc->getPrecision();
auto tmp_desc = current_ext_desc->cloneWithNewPrecision(internal_prc);
if (tmp_desc->isCompatible(*current_internal_desc)) {
auto mem = std::make_shared<Memory>(get_engine(), current_ext_desc);
size_t elements_to_convert = prime_mem()->getDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
auto external_prc = current_ext_desc->getPrecision();
cpu_convert(prime_mem()->getData(), mem->getData(), internal_prc, external_prc, elements_to_convert);
return std::make_shared<Tensor>(mem);
}
}
//reorder
auto mem = std::make_shared<Memory>(get_engine(), current_ext_desc);
mem->load(*(prime_mem()));
return std::make_shared<Tensor>(mem);
}
void VariableStateDoubleBuffer::reset() {
auto new_desc = to_static(m_internal_desc);
for (auto&& mem : m_internal_mem) {
@ -100,18 +116,6 @@ void VariableStateDoubleBuffer::reset() {
}
}
MemoryDescPtr VariableStateDoubleBuffer::to_static(const MemoryDescPtr& desc) {
if (!desc->isDefined()) {
auto&& current_dims = desc->getShape().getDims();
VectorDims new_dims(current_dims.size());
std::transform(current_dims.begin(), current_dims.end(), new_dims.begin(), [](Dim x) {
return x == Shape::UNDEFINED_DIM ? 0 : x; });
return desc->cloneWithNewDims(new_dims, true);
}
return desc;
}
// Swaps the prime/second buffers by flipping the low bit of the buffer index.
void VariableStateDoubleBuffer::commit() {
buffer_num ^= 0x01;
}
@ -128,5 +132,59 @@ MemoryDescPtr VariableStateDoubleBuffer::internal_desc() const {
return m_internal_desc;
}
// The prime buffer holds the current state content.
MemoryPtr VariableStateDoubleBuffer::internal_state_mem() const {
return prime_mem();
}
// Single-buffer variable state: input and output share one memory object
// (in-place state update). When the shape is static, the buffer is filled
// from init_val (or zeroed); otherwise it is redefined to an empty static
// descriptor until the first real assignment.
VariableStateSingleBuffer::VariableStateSingleBuffer(const std::string& name,
const MemoryPtr& buffer,
const MemoryDescPtr& external_desc,
const MemoryCPtr& init_val) :
VariableStateBase(name, external_desc) {
OPENVINO_ASSERT(buffer);
m_internal_mem = buffer;
m_internal_desc = m_internal_mem->getDescPtr();
auto&& shape = m_internal_desc->getShape();
//TODO what if by some reason we already have internal static state while the node is dynamic, is it even possible?
if (shape.isStatic()) {
if (init_val) {
m_internal_mem->load(*init_val);
} else {
m_internal_mem->nullify();
}
} else {
//in the case of the original desc has dynamic shape we create an empty tensor
auto new_desc = to_static(m_internal_desc);
m_internal_mem->redefineDesc(new_desc);
}
}
// Resets the state: shrink to an empty static descriptor and zero the buffer.
void VariableStateSingleBuffer::reset() {
auto new_desc = to_static(m_internal_desc);
m_internal_mem->redefineDesc(new_desc);
m_internal_mem->nullify();
}
// Single-buffer design: input and output refer to the same memory.
MemoryPtr VariableStateSingleBuffer::input_mem() {
return m_internal_mem;
}
// Single-buffer design: input and output refer to the same memory.
MemoryPtr VariableStateSingleBuffer::output_mem() {
return m_internal_mem;
}
// Descriptor captured from the buffer at construction time.
MemoryDescPtr VariableStateSingleBuffer::internal_desc() const {
return m_internal_desc;
}
// The single buffer is the state content itself.
MemoryPtr VariableStateSingleBuffer::internal_state_mem() const {
return m_internal_mem;
}
// No buffer swap needed: the state is updated in place.
void VariableStateSingleBuffer::commit() {
//nothing to do
}
} // namespace intel_cpu
} // namespace ov

View File

@ -28,20 +28,35 @@ public:
virtual MemoryDescPtr internal_desc() const = 0;
};
class VariableStateDoubleBuffer : public IVariableState {
class VariableStateBase : public IVariableState {
public:
using MemBuilder = std::function<MemoryPtr(void)>;
VariableStateBase(const std::string& name, const MemoryDescPtr& external_desc);
public:
VariableStateDoubleBuffer(std::string name,
const MemBuilder& mem_build,
MemoryDescPtr external_desc,
MemoryCPtr init_val);
//ov::IVariableState
void reset() override;
void set_state(const ov::SoPtr<ov::ITensor>& state) override;
ov::SoPtr<ov::ITensor> get_state() const override;
protected:
virtual MemoryPtr internal_state_mem() const = 0;
static MemoryDescPtr to_static(const MemoryDescPtr& desc);
static const dnnl::engine& get_engine();
protected:
MemoryDescPtr m_external_desc;
};
class VariableStateDoubleBuffer : public VariableStateBase {
public:
VariableStateDoubleBuffer(const std::string& name,
const MemoryPtr& first_buffer,
const MemoryPtr& second_buffer,
const MemoryDescPtr& external_desc,
const MemoryCPtr& init_val);
//ov::IVariableState
void reset() override;
//ov::intel_cpu::IVariableState
void commit() override;
@ -50,8 +65,6 @@ public:
MemoryDescPtr internal_desc() const override;
private:
static MemoryDescPtr to_static(const MemoryDescPtr& desc);
void reset_prime_mem(const MemoryPtr& mem) {
m_internal_mem[buffer_num] = mem;
}
@ -68,16 +81,38 @@ private:
return m_internal_mem[buffer_num ^ 0x1];
}
const dnnl::engine& get_engine() const;
MemoryPtr internal_state_mem() const override;
private:
MemoryDescPtr m_external_desc;
MemoryDescPtr m_internal_desc; //mem desc required by the graph internal tensor
std::array<MemoryPtr, 2> m_internal_mem{};
size_t buffer_num = 0;
};
// Variable state backed by a single shared buffer: the producer writes and the
// consumer reads the same memory (in-place update), so commit() is a no-op.
class VariableStateSingleBuffer : public VariableStateBase {
public:
VariableStateSingleBuffer(const std::string& name,
const MemoryPtr& buffer,
const MemoryDescPtr& external_desc,
const MemoryCPtr& init_val);
//ov::IVariableState
void reset() override;
//ov::intel_cpu::IVariableState
void commit() override;
MemoryPtr input_mem() override;
MemoryPtr output_mem() override;
MemoryDescPtr internal_desc() const override;
private:
MemoryPtr internal_state_mem() const override;
private:
MemoryDescPtr m_internal_desc; //mem desc required by the graph internal tensor
MemoryPtr m_internal_mem;
};
using MemStatePtr = std::shared_ptr<IVariableState>;
using MemStateCPtr = std::shared_ptr<const IVariableState>;
} // namespace intel_cpu

View File

@ -1673,7 +1673,7 @@ void Node::updateLastInputDims() {
}
for (size_t i = 0; i < lastInputDims.size(); i++)
lastInputDims[i] = getParentEdgesAtPort(i)[0]->getMemory().getStaticDims();
lastInputDims[i] = getParentEdgesAtPort(i)[0]->getMemory().getDesc().getShape().getDims();
}
bool Node::canFuseSimpleOperation(const NodePtr& node) const {

View File

@ -0,0 +1,39 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "arbitrary_order_desc_creator.h"

#include <algorithm>
namespace ov {
namespace intel_cpu {
// Constructs a descriptor creator for an arbitrary dimension order.
// The order vector must be a permutation without repeated elements.
ArbitraryOrderDescCreator::ArbitraryOrderDescCreator(VectorDims order) :
    m_order(std::move(order)) {
    // std::adjacent_find only detects *adjacent* duplicates, so an unsorted
    // order like {0, 2, 1, 2} would slip through. Validate on a sorted copy
    // to catch any repeated element.
    auto sorted_order = m_order;
    std::sort(sorted_order.begin(), sorted_order.end());
    OPENVINO_ASSERT(std::adjacent_find(sorted_order.begin(), sorted_order.end()) == sorted_order.end(),
        "Can't construct ArbitraryOrderDescCreator, order vector contains repetitive elements",
        vec2str(m_order));
}
// Builds a blocked memory descriptor whose blocked dims are the source shape
// dims permuted by m_order. The shape rank must match the order vector size.
// NOTE(review): order values are not range-checked against dims.size();
// an out-of-range index would fault below — presumably guaranteed by callers.
CpuBlockedMemoryDesc
ArbitraryOrderDescCreator::createDesc(const ov::element::Type& precision, const Shape& srcShape) const {
auto&& dims = srcShape.getDims();
OPENVINO_ASSERT(dims.size() == m_order.size(),
"Couldn't create a tensor descriptor, shape and order size mismatch. Shape: ",
vec2str(dims),
" order: ",
vec2str(m_order));
VectorDims blkDims(dims.size());
for (size_t i = 0; i < dims.size(); ++i) {
blkDims[i] = dims[m_order[i]];
}
return CpuBlockedMemoryDesc(precision, srcShape, blkDims, m_order);
}
// A shape is only representable when its rank equals the order vector size.
size_t ArbitraryOrderDescCreator::getMinimalRank() const {
return m_order.size();
}
} // namespace intel_cpu
} // namespace ov

View File

@ -0,0 +1,24 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "blocked_desc_creator.h"
namespace ov {
namespace intel_cpu {
// Desc creator that produces blocked memory descriptors with a caller-supplied
// dimension order (permutation), rather than one of the predefined layouts.
// NOTE(review): the single-argument constructor is not explicit; marking it
// explicit would be preferable unless implicit conversion is relied upon.
class ArbitraryOrderDescCreator : public BlockedDescCreator {
public:
ArbitraryOrderDescCreator(VectorDims order);
CpuBlockedMemoryDesc createDesc(const ov::element::Type& precision, const Shape& srcShape) const override;
size_t getMinimalRank() const override;
private:
VectorDims m_order;
};
} // namespace intel_cpu
} // namespace ov

View File

@ -395,6 +395,10 @@ Input::Input(const Shape& shape,
const GraphContext::CPtr context)
: Node(type, name, context) {
constant = ConstantType::NoConst;
isDynamic = shape.isDynamic();
if (isDynamic) {
shapeInference = PassThroughShapeInferFactory().makeShapeInfer();
}
if (getType() == Type::Input) {
outputShapes.emplace_back(shape);
addOriginalOutputPrecision(prc);

View File

@ -11,6 +11,8 @@
#include "utils/general_utils.h"
#include "memory_desc/dnnl_blocked_memory_desc.h"
#include "utils/ngraph_utils.hpp"
#include "shape_inference/shape_inference_pass_through.hpp"
#include "common/arbitrary_order_desc_creator.h"
using namespace dnnl;
using namespace InferenceEngine;
@ -23,9 +25,9 @@ std::mutex MemoryNodeVirtualEdge::holderMutex;
MemoryNode::MemoryNode(const std::shared_ptr<ov::Node>& op) {
if (auto assignOp = ov::as_type_ptr<ov::op::util::AssignBase>(op)) {
_id = assignOp->get_variable_id();
m_id = assignOp->get_variable_id();
} else if (auto readValueOp = ov::as_type_ptr<ov::op::util::ReadValueBase>(op)) {
_id = readValueOp->get_variable_id();
m_id = readValueOp->get_variable_id();
} else {
OPENVINO_THROW("Unexpected ov::Node type: ", op->get_type_info().name, " in MemoryNode");
}
@ -61,7 +63,7 @@ MemoryOutput::~MemoryOutput() {
MemoryNodeVirtualEdge::remove(this, holder);
}
MemoryInput& MemoryOutput::getInputNode() {
MemoryInputBase& MemoryOutput::getInputNode() {
OPENVINO_ASSERT(inputNode, "MemoryOutput ", getName(), " doesn't have sibling input");
return *inputNode;
}
@ -196,19 +198,19 @@ void MemoryOutput::executeDynamicImpl(dnnl::stream strm) {
execute(strm);
}
void MemoryOutput::registerInputNode(MemoryInput* node) {
void MemoryOutput::registerInputNode(MemoryInputBase* node) {
if (inputNode == node) { return; }
if (inputNode) { inputNode->deregisterSibling(this); }
inputNode = node;
inputNode->registerOutputNode(this);
}
void MemoryOutput::deregisterSibling(MemoryNode* node) {
void MemoryOutput::deregisterSibling(MemoryInputBase* node) {
if (node == inputNode) { inputNode = nullptr; }
}
bool MemoryInput::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
bool MemoryInputBase::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
try {
if (!one_of(op->get_type_info(),
ov::op::v3::ReadValue::get_type_info_static(),
@ -222,8 +224,8 @@ bool MemoryInput::isSupportedOperation(const std::shared_ptr<const ov::Node>& op
return true;
}
MemoryInput::MemoryInput(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr ctx)
: Input(op, ctx), MemoryNode(op) {
MemoryInputBase::MemoryInputBase(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr ctx)
: Input(op, ctx), MemoryStateNode(op) {
std::string errorMessage;
if (!isSupportedOperation(op, errorMessage)) {
OPENVINO_THROW_NOT_IMPLEMENTED(errorMessage);
@ -233,7 +235,32 @@ MemoryInput::MemoryInput(const std::shared_ptr<ov::Node>& op, const GraphContext
}
}
void MemoryInput::createPrimitive() {
// Transformation-time constructor: builds a MemoryInputBase from explicit
// attributes instead of an ov::Node. The input shape/precision are optional
// since a MemoryInput may have no parent (no initial-value subgraph).
// NOTE(review): `id` is taken by const value — a const ref (or value + move)
// would avoid a copy; signature must match the declaration, so left as is.
MemoryInputBase::MemoryInputBase(const std::string id,
const std::string& name,
const std::string& type,
const Shape& output_shape,
const ov::element::Type& output_prc,
const GraphContext::CPtr context,
const ov::optional<Shape>& input_shape,
const ov::optional<ov::element::Type>& input_prc) :
Input(output_shape, output_prc, name, type, context), MemoryStateNode(id) {
outputShapes.emplace_back(output_shape);
addOriginalOutputPrecision(output_prc);
if (input_shape) {
inputShapes.push_back(*input_shape);
isDynamic = isDynamic || input_shape->isDynamic();
// A dynamic input requires shape inference; pass-through simply propagates
// the input shape to the output.
if (isDynamic && !shapeInference) {
shapeInference = PassThroughShapeInferFactory().makeShapeInfer();
}
}
if (input_prc) {
addOriginalInputPrecision(*input_prc);
}
// We don't need to use a virtual edge since this constructor is used in transformations and
// this is their responsibility to link the input/output nodes properly
}
void MemoryInputBase::createPrimitive() {
Input::createPrimitive();
if (!inputShapes.empty()) {
auto parentEdge = getParentEdgeAt(0);
@ -244,63 +271,7 @@ void MemoryInput::createPrimitive() {
}
}
void MemoryInput::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;
auto&& shape = getOutputShapeAtPort(0);
auto precision = getOriginalOutputPrecisionAtPort(0);
auto&& descCreators = ov::intel_cpu::BlockedDescCreator::getCommonCreators();
NodeConfig config;
if (!getParentEdges().empty()) {
PortConfig inPortConfig;
inPortConfig.inPlace(-1);
inPortConfig.constant(false);
inPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape));
config.inConfs.push_back(std::move(inPortConfig));
}
PortConfig outPortConfig;
outPortConfig.inPlace(0);
outPortConfig.constant(false);
outPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape));
config.outConfs.push_back(std::move(outPortConfig));
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
}
void MemoryInput::initOptimalPrimitiveDescriptor() {
// Mimic the child node memory desc to avoid extra reorder
auto childEdge = getChildEdgeAt(0);
auto child = childEdge->getChild();
auto childPd = child->getSelectedPrimitiveDescriptor();
OPENVINO_ASSERT(childPd,
child->getTypeStr(), " ",
child->getName(),
"failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set");
const auto& childConfig = childPd->getConfig();
auto mem_desc = childConfig.inConfs[childEdge->getOutputNum()].getMemDesc();
auto selectedPd = getSelectedPrimitiveDescriptor();
OPENVINO_ASSERT(selectedPd,
"MemoryInput ",
getName(),
" failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set");
auto config = selectedPd->getConfig();
config.outConfs.front().setMemDesc(mem_desc);
//bypass any checks, we enforce the child descriptor
selectedPd->setConfig(config);
}
void MemoryInput::resolveInPlaceEdges(Edge::LOOK look) {
void MemoryInputBase::resolveInPlaceEdges(Edge::LOOK look) {
if (!(look & Edge::LOOK_UP)) {
Node::resolveInPlaceEdges(look);
return;
@ -324,17 +295,17 @@ void MemoryInput::resolveInPlaceEdges(Edge::LOOK look) {
}
}
MemoryInput::~MemoryInput() {
MemoryInputBase::~MemoryInputBase() {
if (outputNode) { outputNode->deregisterSibling(this); }
MemoryNodeVirtualEdge::remove(this, holder);
}
MemoryOutput& MemoryInput::getOutputNode() {
MemoryOutput& MemoryInputBase::getOutputNode() {
OPENVINO_ASSERT(outputNode, "MemoryOutput ", getName(), " doesn't have sibling input");
return *outputNode;
}
void MemoryInput::assignState(MemStatePtr newState) {
void MemoryInputBase::assignState(MemStatePtr newState) {
assignedMem = newState->input_mem();
OPENVINO_ASSERT(assignedMem,
@ -387,40 +358,18 @@ void MemoryInput::assignState(MemStatePtr newState) {
getOutputNode().assignExtMemory(newState->output_mem(), newState->internal_desc());
}
MemStatePtr MemoryInput::makeState() const {
// assume ov::Tensor is always dense
auto original_desc =
std::make_shared<CpuBlockedMemoryDesc>(getOriginalOutputPrecisionAtPort(0), outputShapes.at(0));
auto mem_desc = getBaseMemDescAtOutputPort(0);
const auto& eng = getEngine();
auto state_name = getId();
// Remove suffix with pair ID. Internal information.
auto suffix_idx = state_name.find("/id=");
if (suffix_idx != std::string::npos) {
state_name = state_name.substr(0, suffix_idx);
}
return std::make_shared<VariableStateDoubleBuffer>(state_name,
[mem_desc, eng](){ return std::make_shared<Memory>(eng, mem_desc); },
original_desc,
getMemoryPtr());
}
void MemoryInput::registerOutputNode(MemoryOutput* node) {
void MemoryInputBase::registerOutputNode(MemoryOutput* node) {
if (outputNode == node) { return; }
if (outputNode) { outputNode->deregisterSibling(this); }
outputNode = node;
outputNode->registerInputNode(this);
}
void MemoryInput::deregisterSibling(MemoryNode* node) {
void MemoryInputBase::deregisterSibling(MemoryOutput* node) {
if (node == outputNode) { outputNode = nullptr; }
}
MemoryNodeVirtualEdge::Holder* MemoryNodeVirtualEdge::registerInput(MemoryInput * node) {
MemoryNodeVirtualEdge::Holder* MemoryNodeVirtualEdge::registerInput(MemoryInputBase * node) {
std::lock_guard<std::mutex> lock{MemoryNodeVirtualEdge::holderMutex};
// in case of output already registered
auto& holder = MemoryNodeVirtualEdge::getExisted();
@ -441,7 +390,7 @@ MemoryNodeVirtualEdge::Holder* MemoryNodeVirtualEdge::registerOutput(MemoryOutpu
auto& holder = MemoryNodeVirtualEdge::getExisted();
auto sibling = MemoryNodeVirtualEdge::getByName(holder, node->getId());
if (sibling != nullptr) {
auto inputNode = dynamic_cast<MemoryInput*>(sibling);
auto inputNode = dynamic_cast<MemoryInputBase*>(sibling);
OPENVINO_ASSERT(inputNode != nullptr);
node->registerInputNode(inputNode);
} else {
@ -458,6 +407,210 @@ void MemoryNodeVirtualEdge::remove(MemoryNode * node, Holder* holder) {
});
}
}
// Declares a single supported configuration: plain (ncsp) layout on the
// optional input and on the output, with the output marked in-place through
// port 0 so the state memory can be exposed without a copy.
void MemoryInput::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;
auto&& shape = getOutputShapeAtPort(0);
auto precision = getOriginalOutputPrecisionAtPort(0);
auto&& descCreators = ov::intel_cpu::BlockedDescCreator::getCommonCreators();
NodeConfig config;
// Input port exists only when an initial-value subgraph is attached.
if (!getParentEdges().empty()) {
PortConfig inPortConfig;
inPortConfig.inPlace(-1);
inPortConfig.constant(false);
inPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape));
config.inConfs.push_back(std::move(inPortConfig));
}
PortConfig outPortConfig;
outPortConfig.inPlace(0);
outPortConfig.constant(false);
outPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape));
config.outConfs.push_back(std::move(outPortConfig));
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
}
void MemoryInput::initOptimalPrimitiveDescriptor() {
    // Mimic the child node memory desc to avoid extra reorder.
    // These node types are sensitive to the input memory layout, so when several
    // children are attached, an edge leading to one of them defines the descriptor.
    static const Type preferredTypes[] = {
        Type::ScaledDotProductAttention,
        Type::MatMul,
        Type::FullyConnected,
        Type::Convolution,
        Type::RNNCell,
        Type::RNNSeq,
        Type::Subgraph
    };

    // Children that must not define the layout (ShapeOf only reads the dims).
    static const Type skipTypes[] = {
        Type::ShapeOf
    };

    auto&& childEdges = getChildEdgesAtPort(0);
    EdgePtr childEdge = childEdges.front();
    if (childEdges.size() > 1) {
        // try to prioritize memory desc
        for (auto&& item : childEdges) {
            auto itemType = item->getChild()->getType();
            if (std::any_of(std::begin(skipTypes), std::end(skipTypes), [=](Type type){ return type == itemType; })) {
                continue;
            }
            if (std::any_of(std::begin(preferredTypes),
                    std::end(preferredTypes), [=](Type type){ return type == itemType; })) {
                childEdge = item;
                break;
            }
        }
    }

    auto child = childEdge->getChild();
    auto childPd = child->getSelectedPrimitiveDescriptor();
    // NOTE: leading space added to the message so it doesn't fuse with the node name
    OPENVINO_ASSERT(childPd,
        child->getTypeStr(), " ",
        child->getName(),
        " failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set");

    const auto& childConfig = childPd->getConfig();
    auto mem_desc = childConfig.inConfs[childEdge->getOutputNum()].getMemDesc();

    auto selectedPd = getSelectedPrimitiveDescriptor();
    OPENVINO_ASSERT(selectedPd,
        "MemoryInput ",
        getName(),
        " failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set");

    auto config = selectedPd->getConfig();
    config.outConfs.front().setMemDesc(mem_desc);
    //bypass any checks, we enforce the child descriptor
    selectedPd->setConfig(config);
}
MemStatePtr MemoryInput::makeState() const {
    // assume ov::Tensor is always dense
    auto denseDesc =
        std::make_shared<CpuBlockedMemoryDesc>(getOriginalOutputPrecisionAtPort(0), outputShapes.at(0));

    auto internalDesc = getBaseMemDescAtOutputPort(0);
    const auto& engine = getEngine();

    // The state name is the node id with the internal "/id=<pair>" suffix stripped.
    auto stateName = getId();
    const auto suffixPos = stateName.find("/id=");
    if (suffixPos != std::string::npos) {
        stateName = stateName.substr(0, suffixPos);
    }

    return std::make_shared<VariableStateDoubleBuffer>(stateName,
        std::make_shared<Memory>(engine, internalDesc),
        std::make_shared<Memory>(engine, internalDesc),
        denseDesc,
        getMemoryPtr());
}
// Plain MemoryInput imposes no extra constraints beyond the base class checks.
bool MemoryInput::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
    return MemoryInputBase::isSupportedOperation(op, errorMessage);
}
void MemoryInputSDPA::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;
auto&& shape = getOutputShapeAtPort(0);
auto precision = getOriginalOutputPrecisionAtPort(0);
auto&& descCreators = ov::intel_cpu::BlockedDescCreator::getCommonCreators();
NodeConfig config;
if (!getParentEdges().empty()) {
PortConfig inPortConfig;
inPortConfig.inPlace(-1);
inPortConfig.constant(false);
inPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape));
config.inConfs.push_back(std::move(inPortConfig));
}
auto&& childEdges = getChildEdgesAtPort(0);
auto itr = std::find_if(childEdges.begin(), childEdges.end(),
[](const EdgePtr& edge){ return Type::ScaledDotProductAttention == edge->getChild()->getType(); });
OPENVINO_ASSERT(itr != childEdges.end(), "MemoryInputSDPA isn't attached to an SDPA node");
auto SDPA = (*itr)->getChild();
auto childPort = (*itr)->getOutputNum();
// Since this is a very specialized implementation, lets mimic SDPA precision and set cabd layout
precision = SDPA->getOriginalInputPrecisionAtPort(childPort);
ArbitraryOrderDescCreator cabdDescCreator({2, 0, 1, 3});
PortConfig outPortConfig;
outPortConfig.inPlace(0);
outPortConfig.constant(false);
outPortConfig.setMemDesc(cabdDescCreator.createSharedDesc(precision, shape));
config.outConfs.push_back(std::move(outPortConfig));
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
}
void MemoryInputSDPA::initOptimalPrimitiveDescriptor() {
    // Mimic the precision expected by the attached SDPA node to avoid an extra conversion,
    // while keeping the layout chosen in initSupportedPrimitiveDescriptors.
    auto&& childEdges = getChildEdgesAtPort(0);
    auto itr = std::find_if(childEdges.begin(), childEdges.end(),
        [](const EdgePtr& edge){ return Type::ScaledDotProductAttention == edge->getChild()->getType(); });
    OPENVINO_ASSERT(itr != childEdges.end(), "MemoryInputSDPA isn't attached to an SDPA node");

    auto childEdge = *itr;
    auto child = childEdge->getChild();
    auto childPd = child->getSelectedPrimitiveDescriptor();
    // NOTE: leading space added to the message so it doesn't fuse with the node name
    OPENVINO_ASSERT(childPd,
        child->getTypeStr(), " ",
        child->getName(),
        " failed initOptimalPrimitiveDescriptor() call, preferable primitive descriptor is not set");

    const auto& childConfig = childPd->getConfig();
    auto childPrecision = childConfig.inConfs[childEdge->getOutputNum()].getMemDesc()->getPrecision();

    auto selectedPd = getSelectedPrimitiveDescriptor();
    OPENVINO_ASSERT(selectedPd,
        "MemoryInputSDPA ",
        getName(),
        " failed initOptimalPrimitiveDescriptor() call, preferable primitive descriptor is not set");

    // Only the precision of the output descriptor is replaced; the memory layout is preserved.
    auto config = selectedPd->getConfig();
    auto memDesc = config.outConfs.front().getMemDesc();
    auto newMemDesc = memDesc->cloneWithNewPrecision(childPrecision);
    config.outConfs.front().setMemDesc(newMemDesc);
    //bypass any checks, we enforce the child descriptor precision
    selectedPd->setConfig(config);
}
MemStatePtr MemoryInputSDPA::makeState() const {
    // assume ov::Tensor is always dense
    auto denseDesc =
        std::make_shared<CpuBlockedMemoryDesc>(getOriginalOutputPrecisionAtPort(0), outputShapes.at(0));

    auto internalDesc = getBaseMemDescAtOutputPort(0);
    const auto& engine = getEngine();

    // Strip the internal "/id=<pair>" suffix from the node id to form the state name.
    auto stateName = getId();
    const auto suffixPos = stateName.find("/id=");
    if (suffixPos != std::string::npos) {
        stateName = stateName.substr(0, suffixPos);
    }

    // Single buffer backed by a reallocating memory manager (MemoryMngrRealloc copies
    // the old contents over on growth), so no double buffering is required here.
    auto buffer = std::make_shared<Memory>(engine,
        internalDesc,
        std::make_shared<DnnlMemoryMngr>(make_unique<MemoryMngrRealloc>()));

    return std::make_shared<VariableStateSingleBuffer>(stateName, buffer, denseDesc, getMemoryPtr());
}
} // namespace node
} // namespace intel_cpu
} // namespace ov

View File

@ -9,6 +9,7 @@
#include "ie_algorithm.hpp"
#include "input.h"
#include <node.h>
#include <ov_optional.hpp>
#include <memory_state.h>
#include <proxy_mem_mgr.h>
#include <string>
@ -20,20 +21,24 @@ namespace intel_cpu {
namespace node {
class MemoryOutput;
class MemoryInput;
class MemoryInputBase;
class MemoryNode { //TODO , segregate interfaces
std::string _id;
class MemoryNode {
public:
explicit MemoryNode(std::string id) : _id(id) {}
explicit MemoryNode(std::string id) : m_id(id) {}
explicit MemoryNode(const std::shared_ptr<ov::Node>& op);
virtual ~MemoryNode() = default;
std::string getId() const {
return _id;
return m_id;
}
virtual void registerInputNode(MemoryInput*) = 0;
virtual void registerOutputNode(MemoryOutput*) = 0;
virtual void deregisterSibling(MemoryNode*) = 0;
private:
std::string m_id;
};
class MemoryStateNode : public MemoryNode {
public:
using MemoryNode::MemoryNode;
virtual void assignState(MemStatePtr newState) = 0;
virtual MemStatePtr makeState() const = 0;
};
@ -60,7 +65,7 @@ public:
}
static Holder* registerOutput(MemoryOutput * node);
static Holder* registerInput(MemoryInput * node);
static Holder* registerInput(MemoryInputBase * node);
static void remove(MemoryNode * node, Holder* holder);
static std::mutex holderMutex;
};
@ -81,12 +86,8 @@ public:
}
void resolveInPlaceEdges(Edge::LOOK look) override;
void registerInputNode(MemoryInput* node) override;
void registerOutputNode(MemoryOutput* node) override {
OPENVINO_THROW("MemoryOutput node has no MemoryOutput type sibling!");
}
void deregisterSibling(MemoryNode* node) override;
void registerInputNode(MemoryInputBase* node);
void deregisterSibling(MemoryInputBase* node);
bool needShapeInfer() const override { return false; }
bool needPrepareParams() const override { return false; }
@ -94,38 +95,38 @@ public:
void assignExtMemory(const MemoryPtr& mem, const MemoryDescPtr& memDesc);
private:
MemoryInput& getInputNode();
void assignState(MemStatePtr newState) override {
OPENVINO_THROW("Unexpected MemoryOutput::assignState call"); //TODO , segregate interfaces
}
MemStatePtr makeState() const override {
OPENVINO_THROW("Unexpected MemoryOutput::makeState call"); //TODO , segregate interfaces
}
MemoryInputBase& getInputNode();
private:
/**
* @brief keeps reference to input sibling node
*/
MemoryInput* inputNode = nullptr;
MemoryInputBase* inputNode = nullptr;
MemoryPtr assignedMem = nullptr;
MemoryDescPtr extMemDesc = nullptr; // used for resize
MemoryNodeVirtualEdge::Holder* holder = nullptr;
ProxyMemoryMngrPtr memMngr = nullptr;
};
class MemoryInput : public Input, public MemoryNode {
class MemoryInputBase : public Input, public MemoryStateNode {
public:
MemoryInput(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context);
~MemoryInput() override;
MemoryInputBase(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context);
MemoryInputBase(const std::string id,
const std::string& name,
const std::string& type,
const Shape& output_shape,
const ov::element::Type& output_prc,
const GraphContext::CPtr context,
const ov::optional<Shape>& input_shape,
const ov::optional<ov::element::Type>& input_prc);
~MemoryInputBase() override;
static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
bool created() const override {
return getType() == Type::MemoryInput;
}
void initSupportedPrimitiveDescriptors() override;
void initOptimalPrimitiveDescriptor() override;
void execute(dnnl::stream strm) override {/*pass*/}
void executeDynamicImpl(dnnl::stream strm) override {/*pass*/}
@ -133,17 +134,11 @@ public:
void resolveInPlaceEdges(Edge::LOOK look) override;
void registerInputNode(MemoryInput* node) override {
OPENVINO_THROW("MemoryInput node has no MemoryInput type sibling!");
}
void registerOutputNode(MemoryOutput* node) override;
void deregisterSibling(MemoryNode* node) override;
void registerOutputNode(MemoryOutput* node);
void deregisterSibling(MemoryOutput* node);
// May be extracted to some interface when necessary
void assignState(MemStatePtr newState) override;
MemStatePtr makeState() const override;
private:
MemoryOutput& getOutputNode();
private:
@ -156,6 +151,27 @@ private:
ProxyMemoryMngrPtr memMngr = nullptr;
};
class MemoryInput : public MemoryInputBase {
public:
using MemoryInputBase::MemoryInputBase;
static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
void initSupportedPrimitiveDescriptors() override;
void initOptimalPrimitiveDescriptor() override;
MemStatePtr makeState() const override;
};
class MemoryInputSDPA : public MemoryInputBase {
public:
using MemoryInputBase::MemoryInputBase;
static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
void initSupportedPrimitiveDescriptors() override;
void initOptimalPrimitiveDescriptor() override;
MemStatePtr makeState() const override;
};
} // namespace node
} // namespace intel_cpu
} // namespace ov

View File

@ -15,11 +15,12 @@
#include <shape_inference/shape_inference_internal_dyn.hpp>
#include <vector>
#include "common/cpu_memcpy.h"
#include "openvino/core/parallel.hpp"
#include "memory_desc/cpu_memory_desc_utils.h"
#include "memory_desc/dnnl_blocked_memory_desc.h"
#include "utils/plain_tensor.hpp"
#include <openvino/op/scaled_dot_product_attention.hpp>
#include "common/arbitrary_order_desc_creator.h"
#ifdef OV_CPU_WITH_MLAS
# include "mlas/sgemm.hpp"
@ -576,48 +577,93 @@ struct MHASingleToken {
template <ScaledDotProductAttention::KernelTypes KType, typename T>
struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAttention::Executor {
PlainTensor<T> q_input; // f32[B, L1, H*S] / [B, H, L1, S]
PlainTensor<T> k_input; // f32[B, L1, H*S]
PlainTensor<T> v_input; // f32[B, L1, H*S]
PlainTensor<T> k_cache; // f32[B, H, max_kvLen, S]
PlainTensor<T> v_cache; // f32[B, H, max_kvLen, S]
PlainTensor<T> q_input; // f32[B, H, L1, S]
PlainTensor<T> k_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S]
PlainTensor<T> v_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S]
PlainTensor<int32_t> beam_table; // i32[B, max_kvLen]
PlainTensor<float> attn_mask; // f32[B, qLen + kvLen]
float scale_input = 0.0f; // f32[B, qLen + kvLen]
PlainTensor<float> cos_tab; // f32[max_kv_len, rotary_dims//2]
PlainTensor<float> sin_tab; // f32[max_kv_len, rotary_dims//2]
PlainTensor<T> output_emb; // f32[B, L1, H*S]
PlainTensor<float> attn_buf; // f32[[B|1],[H|1], L1|1, L0+L1]
float scale_input = 0.0f;
MHAKernel<KType, T> kernel;
MHASingleToken<T> kernel_single_token;
PlainTensor<T> m_query_emb; // query with RoPE position embedding
size_t B, H, L1, L0, S;
ScaledDotProductAttention::Config config;
AttentionExecutor(const ScaledDotProductAttention::Config& _config) : config(_config) {}
Config config;
AttentionExecutor(const Config& _config) : attn_buf(true), config(_config) {}
void prepare_attn_mask(MemoryPtr attn_input) {
attn_mask.resize(attn_input->getStaticDims());
attn_buf.resize(attn_input->getStaticDims());
auto p = reinterpret_cast<uint8_t*>(attn_input->getData());
for (size_t i = 0; i < attn_input->getSize(); i++)
attn_mask.data()[i] = p[i] ? 0.0f : -FLT_MAX;
attn_buf.data()[i] = p[i] ? 0.0f : -FLT_MAX;
}
// Concatenates the freshly computed k/v projections with the cached past key/value.
// When concat is fused into this node, the L1 new rows are appended after the L0
// cached rows directly in the output buffers; otherwise the k/v inputs already hold
// the full concatenated sequence and are forwarded as-is.
// Side effects: sets member L0; binds past_k_output/past_v_output to the memory
// holding the full L0+L1 sequence.
void concat_pastkv(const std::vector<MemoryPtr>& inputs,
                   const std::vector<MemoryPtr>& outputs,
                   const PlainTensor<T>& k_input,
                   const PlainTensor<T>& v_input,
                   PlainTensor<T>& past_k_output,
                   PlainTensor<T>& past_v_output) {
    if (config.config.fuse_concat) {
        // k/v carry only the new tokens; dim 1 may be H or 1 (multiquery), hence the wildcard 0
        k_input.assert_dims({B, 0, L1, S}, true);
        v_input.assert_dims({B, 0, L1, S}, true);
        // past k/v are the last two node inputs
        auto past_k_idx = inputs.size() - 2;
        auto past_k_mem = inputs[past_k_idx + 0];
        L0 = past_k_mem->getStaticDims()[2];
        // k,v may support multiquery
        auto Hk = past_k_mem->getStaticDims()[1];
        // [B, H, L0, S]
        past_k_output.reset(outputs[1]);
        past_v_output.reset(outputs[2]);
        // Append the L1 new rows after the existing L0 rows, one S-sized row per task.
        parallel_for3d(B, Hk, L1, [&](size_t b, size_t h, size_t m) {
            std::memcpy(&past_k_output.at({b, h, m + L0, 0}),
                        &k_input.at({b, h, m, 0}),
                        S * sizeof(T));
            std::memcpy(&past_v_output.at({b, h, m + L0, 0}),
                        &v_input.at({b, h, m, 0}),
                        S * sizeof(T));
        });
        if (!config.is_concat_inplaced) {
            // Outputs do not alias the past k/v inputs, so the first L0 cached rows
            // must be copied into the outputs explicitly as well.
            PlainTensor<T> past_k_input, past_v_input;
            past_k_input.reset(past_k_mem);
            past_v_input.reset(inputs[past_k_idx + 1]);
            parallel_for3d(B, Hk, L0, [&](size_t b, size_t h, size_t m) {
                std::memcpy(&past_k_output.at({b, h, m, 0}),
                            &past_k_input.at({b, h, m, 0}),
                            S * sizeof(T));
                std::memcpy(&past_v_output.at({b, h, m, 0}),
                            &past_v_input.at({b, h, m, 0}),
                            S * sizeof(T));
            });
        }
    } else {
        // k,v inputs are already concatenated
        L0 = k_input.size(2) - L1;
        k_input.assert_dims({B, 0, L0 + L1, S}, true);
        v_input.assert_dims({B, 0, L0 + L1, S}, true);
        past_k_output = k_input;
        past_v_output = v_input;
    }
}
void execute(dnnl::stream strm, const std::vector<MemoryPtr>& inputs, const std::vector<MemoryPtr>& outputs) override {
bool has_out_transpose = config.output_BLHxS;
bool fuse_causal_attn = config.fuse_causal_attn;
bool is_causal = config.is_causal;
auto input_num = inputs.size();
bool has_out_transpose = config.config.output_BLHxS;
bool fuse_causal_attn = config.config.fuse_causal_attn;
bool is_causal = config.config.is_causal;
const bool fuse_concat = config.config.fuse_concat;
auto input_num = inputs.size() - (fuse_concat ? 2 : 0);
q_input.reset(inputs[0]);
k_input.reset(inputs[1]);
v_input.reset(inputs[2]);
PlainTensor<float> attn_mask;
if (input_num > 3) {
// attn_mask
if (inputs[3]->getDesc().getPrecision() == ov::element::u8) {
// bool->f32
prepare_attn_mask(inputs[3]);
attn_mask = attn_buf;
} else {
attn_mask.reset(inputs[3]);
}
@ -627,30 +673,22 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
}
}
size_t B, H, L1, L0, S;
// q, k, v: [B, H, L1, S]
// q: [B, H, L1, S]
B = q_input.size(0);
H = q_input.size(1);
L1 = q_input.size(2);
L0 = k_input.size(2) - L1;
S = q_input.size(-1);
ov::intel_cpu::PlainTensor<T> output_emb(outputs[0]);
PlainTensor<T> present_key, present_value;
concat_pastkv(inputs, outputs, k_input, v_input, present_key, present_value);
q_input.assert_dims({B, H, L1, S});
k_input.assert_dims({B, 0, L0 + L1, S}, true);
v_input.assert_dims({B, 0, L0 + L1, S}, true);
m_query_emb = q_input;
present_key = k_input;
present_value = v_input;
ov::intel_cpu::PlainTensor<T> output_emb(outputs[0]);
bool auto_causal;
bool use_attn_mask;
if (fuse_causal_attn) {
assert(attn_mask);
attn_mask.assert_dims({B, 1, 1, L0 + L1});
attn_mask.assert_dims({B, 1, L1, L0 + L1});
auto_causal = true;
use_attn_mask = true;
} else {
@ -677,7 +715,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
if (L1 > 1) {
// multi-token version
kernel(strm, m_query_emb, present_key, present_value, {}, use_attn_mask ? attn_mask : PlainTensor<float>(),
kernel(strm, q_input, k_input, v_input, {}, use_attn_mask ? attn_mask : PlainTensor<float>(),
output_emb, has_out_transpose, auto_causal, scale_input);
} else {
// 1-token version
@ -685,7 +723,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
// 1, in matrix mutiply, using AMX is not efficency because the M dimension of A will alway be 1
// 2, using float will save the repack cost which typically is required for bf16/int8 opt
// 3, using dot product can leverage the SIMD while easily adapt to indirect kv cache
kernel_single_token(m_query_emb, present_key, present_value, {}, use_attn_mask ? attn_mask : PlainTensor<float>(),
kernel_single_token(q_input, present_key, present_value, {}, use_attn_mask ? attn_mask : PlainTensor<float>(),
output_emb, beam_table, has_out_transpose, auto_causal, scale_input);
}
}
@ -700,9 +738,10 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptr<ngrap
const auto node = std::dynamic_pointer_cast<const ov::op::v13::ScaledDotProductAttention>(op);
if (node) {
m_config.is_causal = node->get_causal();
m_config.config.is_causal = node->get_causal();
} else {
OPENVINO_THROW("CPU: cast to v13::ScaledDotProductAttention failed.");
const auto node = std::dynamic_pointer_cast<const ScaledDotProductAttentionStub>(op);
m_config.config = node->get_config();
}
}
@ -711,6 +750,80 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() {
return;
auto rtPrecision = getOriginalInputPrecisionAtPort(0);
NodeConfig config;
auto& creatorsMap = BlockedDescCreator::getCommonCreators();
auto orginSDPInputNumber = getOriginalInputsNumber() - (m_config.config.fuse_concat ? 2 : 0);
config.inConfs.resize(getOriginalInputsNumber());
config.outConfs.resize(getOriginalOutputsNumber());
config.inConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
rtPrecision, getInputShapeAtPort(0)));
config.inConfs[1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
rtPrecision, getInputShapeAtPort(1)));
config.inConfs[2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
rtPrecision, getInputShapeAtPort(2)));
auto nextPortIdx = 3;
if (orginSDPInputNumber > 3) {
// attn_mask
if (getOriginalInputPrecisionAtPort(nextPortIdx) == ov::element::u8) {
config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
ov::element::u8, getInputShapeAtPort(nextPortIdx)));
} else {
config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
ov::element::f32, getInputShapeAtPort(nextPortIdx)));
}
nextPortIdx++;
}
if (orginSDPInputNumber > 4) {
config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
ov::element::f32, getInputShapeAtPort(nextPortIdx)));
}
if (m_config.config.fuse_concat) {
ArbitraryOrderDescCreator cabdDescCreator({2, 0, 1, 3});
config.inConfs[orginSDPInputNumber + 0].setMemDesc(cabdDescCreator.createSharedDesc(
rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 0)));
config.inConfs[orginSDPInputNumber + 1].setMemDesc(cabdDescCreator.createSharedDesc(
rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 1)));
config.outConfs[1].setMemDesc(cabdDescCreator.createSharedDesc(
rtPrecision, getOutputShapeAtPort(1)));
config.outConfs[1].inPlace(orginSDPInputNumber + 0);
config.outConfs[2].setMemDesc(cabdDescCreator.createSharedDesc(
rtPrecision, getOutputShapeAtPort(2)));
config.outConfs[2].inPlace(orginSDPInputNumber + 1);
}
config.outConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
rtPrecision, getOutputShapeAtPort(0)));
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::ref_any);
// may fallback to abcd without inplace
if (m_config.config.fuse_concat) {
config.inConfs[orginSDPInputNumber + 0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 0)));
config.inConfs[orginSDPInputNumber + 1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 1)));
config.outConfs[1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
rtPrecision, getOutputShapeAtPort(1)));
config.outConfs[1].inPlace(-1);
config.outConfs[2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
rtPrecision, getOutputShapeAtPort(2)));
config.outConfs[2].inPlace(-1);
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::ref_any);
}
}
void ScaledDotProductAttention::createPrimitive() {
if (m_config.config.fuse_concat) {
auto desc = getSelectedPrimitiveDescriptor();
if (desc == nullptr)
OPENVINO_THROW("has unidentified preferable primitive descriptor");
m_config.is_concat_inplaced = desc->getConfig().outConfs[1].inPlace() >= 0;
}
auto rtPrecision = getOriginalInputPrecisionAtPort(0);
if (rtPrecision == ov::element::bf16) {
m_executor = std::make_shared<AttentionExecutor<KT_ONEDNN, ov::bfloat16>>(m_config);
} else {
@ -722,29 +835,6 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() {
m_executor = std::make_shared<AttentionExecutor<KT_ONEDNN, float>>(m_config);
#endif
}
// initialize input ports
std::vector<PortConfigurator> inPortConfigs;
inPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getInputShapeAtPort(0), false, -1);
inPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getInputShapeAtPort(1), false, -1);
inPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getInputShapeAtPort(2), false, -1);
if (getOriginalInputsNumber() > 3) {
// attn_mask
if (getOriginalInputPrecisionAtPort(3) == ov::element::u8) {
inPortConfigs.emplace_back(LayoutType::ncsp, ov::element::u8, getInputShapeAtPort(3), false, -1);
} else {
inPortConfigs.emplace_back(LayoutType::ncsp, ov::element::f32, getInputShapeAtPort(3), false, -1);
}
}
if (getOriginalInputsNumber() > 4) {
inPortConfigs.emplace_back(LayoutType::ncsp, ov::element::f32, getInputShapeAtPort(4), false, -1);
}
// initialize output port
std::vector<PortConfigurator> outPortConfigs;
outPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getOutputShapeAtPort(0), false, -1);
addSupportedPrimDesc(inPortConfigs, outPortConfigs, impl_desc_type::ref_any);
}
void ScaledDotProductAttention::execute(dnnl::stream strm) {
@ -760,8 +850,9 @@ void ScaledDotProductAttention::execute(dnnl::stream strm) {
bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
try {
if (!std::dynamic_pointer_cast<const ov::op::v13::ScaledDotProductAttention>(op)) {
errorMessage = "Only ScaledDotProductAttention operation are supported";
if (!std::dynamic_pointer_cast<const ov::op::v13::ScaledDotProductAttention>(op) &&
!std::dynamic_pointer_cast<const ScaledDotProductAttentionStub>(op)) {
errorMessage = "Only ScaledDotProductAttention or ScaledDotProductAttentionStub operation are supported";
return false;
}
// expect shape of q: [B, H, L, S]
@ -770,7 +861,14 @@ bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr<const
errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank);
return false;
}
if (op->get_input_size() > 3) {
int orgSDPAInput = static_cast<int>(op->get_input_size());
const auto node = std::dynamic_pointer_cast<const ScaledDotProductAttentionStub>(op);
if (node) {
if (node->get_config().fuse_concat) {
orgSDPAInput -= 2;
}
}
if (orgSDPAInput > 3) {
inRank = op->get_input_partial_shape(3).size();
if (inRank > 4u) {
errorMessage = "Doesn't support 'attention mask' with rank: " + std::to_string(inRank);

View File

@ -10,6 +10,8 @@
#include <string>
#include <vector>
#include "transformations/cpu_opset/common/op/sdp.hpp"
namespace ov {
namespace intel_cpu {
namespace node {
@ -22,6 +24,10 @@ public:
bool created() const override {
return getType() == Type::ScaledDotProductAttention;
}
// pastkv may have zero dimension
bool isExecutable() const override {
return true;
}
bool needPrepareParams() const override {
return false;
}
@ -30,6 +36,7 @@ public:
}
void initSupportedPrimitiveDescriptors() override;
void execute(dnnl::stream strm) override;
void createPrimitive() override;
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
enum KernelTypes { KT_REF, KT_ONEDNN, KT_MLAS};
@ -40,9 +47,8 @@ private:
};
struct Config {
bool output_BLHxS = false;
bool fuse_causal_attn = false;
bool is_causal = false;
ScaledDotProductAttentionStub::Config config;
bool is_concat_inplaced = false;
};
Config m_config;

View File

@ -51,12 +51,34 @@ void ShapeOf::initSupportedPrimitiveDescriptors() {
ov::element::Type precision = getOriginalInputPrecisionAtPort(0);
const LayoutType dataFormats[4] = { LayoutType::ncsp, LayoutType::nspc, LayoutType::nCsp16c, LayoutType::nCsp8c };
for (const auto &df : dataFormats) {
addSupportedPrimDesc({{df, precision}},
{{LayoutType::ncsp, ov::element::i32}},
impl_desc_type::ref);
}
addSupportedPrimDesc({{LayoutType::ncsp, precision}},
{{LayoutType::ncsp, ov::element::i32}},
impl_desc_type::ref);
}
void ShapeOf::initOptimalPrimitiveDescriptor() {
    // Mimic the parent node memory desc to avoid extra reorder
    auto parentEdge = getParentEdgeAt(0);
    auto parent = parentEdge->getParent();
    auto parentPd = parent->getSelectedPrimitiveDescriptor();
    // NOTE: leading space added to the message so it doesn't fuse with the node name
    OPENVINO_ASSERT(parentPd,
        parent->getTypeStr(), " ",
        parent->getName(),
        " failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set");

    const auto& parentConfig = parentPd->getConfig();
    auto mem_desc = parentConfig.outConfs[parentEdge->getInputNum()].getMemDesc();

    auto selected_pd = getSelectedPrimitiveDescriptor();
    OPENVINO_ASSERT(selected_pd,
        "ShapeOf ",
        getName(),
        " failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set");

    auto config = selected_pd->getConfig();
    config.inConfs.front().setMemDesc(mem_desc);
    //bypass any checks, we enforce the parent descriptor
    selected_pd->setConfig(config);
}
bool ShapeOf::isExecutable() const {
@ -66,12 +88,12 @@ bool ShapeOf::isExecutable() const {
void ShapeOf::execute(dnnl::stream strm) {
auto inPtr = getParentEdgeAt(0)->getMemoryPtr();
auto outPtr = getChildEdgeAt(0)->getMemoryPtr();
auto inDims = inPtr->getStaticDims();
auto&& inDims = inPtr->getStaticDims();
size_t dimsCount = inDims.size();
if (outPtr->getStaticDims().size() != 1 || dimsCount != outPtr->getStaticDims()[0])
OPENVINO_THROW(errorPrefix, "has inconsistent input shape and output size");
auto *dst = reinterpret_cast<int *>(getChildEdgeAt(0)->getMemoryPtr()->getData());
auto* dst = reinterpret_cast<int *>(outPtr->getData());
for (size_t i = 0; i < dimsCount; i++) {
dst[i] = inDims[i];

View File

@ -20,6 +20,7 @@ public:
void getSupportedDescriptors() override;
void initSupportedPrimitiveDescriptors() override;
void initOptimalPrimitiveDescriptor() override;
void execute(dnnl::stream strm) override;
bool created() const override;
bool needPrepareParams() const override {return false;};

View File

@ -0,0 +1,57 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "sdp.hpp"
#include <algorithm>
#include "transformations/itt.hpp"
// Builds the fused SDPA+Concat stub op and immediately runs shape/type inference.
ov::intel_cpu::ScaledDotProductAttentionStub::ScaledDotProductAttentionStub(const OutputVector& args, const Config& cfg)
    : Op(args),
      m_config(cfg) {
    constructor_validate_and_infer_types();
}
std::shared_ptr<ov::Node> ov::intel_cpu::ScaledDotProductAttentionStub::clone_with_new_inputs(
    const ov::OutputVector& new_args) const {
    INTERNAL_OP_SCOPE(ScaledDotProductAttentionStub_with_new_inputs);
    check_new_args_count(this, new_args);
    // Replicate the node with the same fusion config but rewired inputs.
    auto cloned = std::make_shared<ov::intel_cpu::ScaledDotProductAttentionStub>(new_args, m_config);
    return cloned;
}
// Shape/type inference for the fused SDPA+Concat stub.
// Output 0 follows q; outputs 1 and 2 are the present k/v whose sequence axis
// grows by q's sequence length.
void ov::intel_cpu::ScaledDotProductAttentionStub::validate_and_infer_types() {
    INTERNAL_OP_SCOPE(ScaledDotProductAttentionStub_validate_and_infer_types);
    auto input_num = get_input_size();
    // [B, H, L1, S]
    auto q_ps = get_input_partial_shape(0);
    // [B, H, L0, S]
    auto past_kv_ps = get_input_partial_shape(input_num - 1);

    NODE_VALIDATION_CHECK(this, m_config.output_BLHxS == false);
    NODE_VALIDATION_CHECK(this, q_ps.size() >= 3);
    if (past_kv_ps.rank().is_static()) {
        NODE_VALIDATION_CHECK(this, q_ps.size() == past_kv_ps.size());
        for (size_t i = 0; i < q_ps.size(); i++) {
            // The sequence-length axis (rank - 2) may legitimately differ between q and past kv.
            if (i == q_ps.size() - 2)
                continue;
            NODE_VALIDATION_CHECK(this, q_ps[i].compatible(past_kv_ps[i]));
        }
        // Present kv length = past length L0 + newly appended length L1.
        past_kv_ps[q_ps.size() - 2] += q_ps[q_ps.size() - 2];
    }
    set_output_type(0, get_input_element_type(0), q_ps);
    set_output_type(1, get_input_element_type(input_num - 1), past_kv_ps);
    set_output_type(2, get_input_element_type(input_num - 1), past_kv_ps);
}
bool ov::intel_cpu::ScaledDotProductAttentionStub::visit_attributes(ov::AttributeVisitor& visitor) {
    INTERNAL_OP_SCOPE(ScaledDotProductAttentionStub_visit_attributes);
    // Serialize all fusion flags nested under a "config" structure.
    visitor.start_structure("config");
    auto& cfg = m_config;
    visitor.on_attribute("output_BLHxS", cfg.output_BLHxS);
    visitor.on_attribute("fuse_causal_attn", cfg.fuse_causal_attn);
    visitor.on_attribute("is_causal", cfg.is_causal);
    visitor.on_attribute("fuse_concat", cfg.fuse_concat);
    visitor.finish_structure();
    return true;
}

View File

@ -0,0 +1,50 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include "openvino/op/op.hpp"
namespace ov {
namespace intel_cpu {
/// \brief Scaled dot product attention from PyTorch, fused with Concat
///
/// \ingroup ov_ops_cpp_api
class ScaledDotProductAttentionStub : public ov::op::Op {
public:
    OPENVINO_OP("ScaledDotProductAttentionStub", "cpu_plugin_opset");

    ScaledDotProductAttentionStub() = default;

    // Fusion flags; serialized/deserialized via visit_attributes under "config".
    struct Config {
        bool output_BLHxS = false; // true implies that output is [B,L,H*S]
        bool fuse_causal_attn = false; // fuse causal mask and attn mask into attn_mask
        bool is_causal = false; // apply causal mask internally
        bool fuse_concat = false; // fuse (concat->sdp) ==> sdp
    };

    // args: SDPA inputs (q, k, v, optional mask/scale; when fuse_concat is set,
    // the trailing inputs carry the past k/v — TODO confirm exact input order with the fusion pass)
    ScaledDotProductAttentionStub(const OutputVector& args, const Config& cfg);

    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
    bool visit_attributes(AttributeVisitor& visitor) override;
    void validate_and_infer_types() override;

    // Read-only access to the fusion flags.
    const Config& get_config() const {
        return m_config;
    }

    // Mutable access to the fusion flags.
    Config& get_config() {
        return m_config;
    }

private:
    Config m_config;
};
} // namespace intel_cpu
} // namespace ov

View File

@ -0,0 +1,121 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "stateful_sdp_fusion.hpp"
#include <cstdint>
#include <limits>
#include <openvino/core/rt_info.hpp>
#include <openvino/opsets/opset1.hpp>
#include <openvino/opsets/opset13.hpp>
#include <openvino/opsets/opset6.hpp>
#include <openvino/opsets/opset8.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
#include <transformations/utils/utils.hpp>
#include "itt.hpp"
#include "ov_ops/type_relaxed.hpp"
#include "transformations/cpu_opset/common/op/sdp.hpp"
namespace ov {
namespace intel_cpu {
StatefulSDPFusion::StatefulSDPFusion() {
    MATCHER_SCOPE(StatefulSDPFusion);
    using namespace ov::pass::pattern;
    // Pattern: ReadValue (optionally followed by Convert) feeds a Concat with the
    // fresh K/V tensor; both Concat outputs feed a ScaledDotProductAttention with
    // 3, 4, or 5 inputs (q, k, v [, attn_mask [, scale]]).
    auto past_k = wrap_type<opset6::ReadValue>();
    auto past_v = wrap_type<opset6::ReadValue>();
    auto convert_past_k = wrap_type<opset1::Convert>({past_k});
    auto convert_past_v = wrap_type<opset1::Convert>({past_v});
    auto concat_input_k = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{past_k, convert_past_k});
    auto concat_input_v = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{past_v, convert_past_v});
    auto concat_k = wrap_type<opset6::Concat>({concat_input_k, any_input()});
    auto concat_v = wrap_type<opset6::Concat>({concat_input_v, any_input()});
    auto sdp0 = wrap_type<opset13::ScaledDotProductAttention>({any_input(), concat_k, concat_v});
    auto sdp1 = wrap_type<opset13::ScaledDotProductAttention>({any_input(), concat_k, concat_v, any_input()});
    auto sdp2 = wrap_type<opset13::ScaledDotProductAttention>({any_input(), concat_k, concat_v, any_input(), any_input()});
    auto sdp = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{sdp0, sdp1, sdp2});
    ov::matcher_pass_callback callback = [=](Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        auto root = m.get_match_root();
        // Locate the Assign consumer of `out` (the Concat output), possibly behind a
        // single-consumer Convert. Requires exactly two consumers: the SDPA and the
        // Assign (or its Convert). Leaves `assign` null when the pattern is not met.
        auto find_assign = [&](const ov::Output<ov::Node>& out, opset6::Assign*& assign, opset1::Convert*& cvt) {
            auto present_to = out.get_target_inputs();
            if (present_to.size() != 2)
                return;
            for (auto& to : present_to) {
                auto to_node = to.get_node();
                if (auto convert = dynamic_cast<opset1::Convert*>(to_node)) {
                    auto cvt_targets = convert->get_output_target_inputs(0);
                    if (cvt_targets.size() == 1) {
                        // Look through the Convert and remember it so it can be rewired later.
                        to_node = cvt_targets.begin()->get_node();
                        cvt = convert;
                    }
                }
                assign = dynamic_cast<opset6::Assign*>(to_node);
                if (assign)
                    return;
            }
        };
        std::shared_ptr<opset1::Convert> read_cvt_k_node, read_cvt_v_node;
        const auto sdp_node = ov::as_type_ptr<opset13::ScaledDotProductAttention>(root);
        const auto past_k_node = ov::as_type_ptr<opset6::ReadValue>(pattern_map.at(past_k).get_node_shared_ptr());
        const auto past_v_node = ov::as_type_ptr<opset6::ReadValue>(pattern_map.at(past_v).get_node_shared_ptr());
        const auto concat_k_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_k).get_node_shared_ptr());
        const auto concat_v_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_v).get_node_shared_ptr());
        if (pattern_map.count(convert_past_k)) {
            read_cvt_k_node = ov::as_type_ptr<opset1::Convert>(pattern_map.at(convert_past_k).get_node_shared_ptr());
            read_cvt_v_node = ov::as_type_ptr<opset1::Convert>(pattern_map.at(convert_past_v).get_node_shared_ptr());
        }
        opset6::Assign* assign_k_node = nullptr, *assign_v_node = nullptr;
        opset1::Convert* assign_cvt_k_node = nullptr, *assign_cvt_v_node = nullptr;
        find_assign(concat_k_node, assign_k_node, assign_cvt_k_node);
        if (!assign_k_node)
            return false;
        // ReadValue and Assign must refer to the same state variable.
        if (past_k_node->get_variable_id() != assign_k_node->get_variable_id())
            return false;
        find_assign(concat_v_node, assign_v_node, assign_cvt_v_node);
        if (!assign_v_node)
            return false;
        if (past_v_node->get_variable_id() != assign_v_node->get_variable_id())
            return false;
        // Rewire the SDPA inputs: k/v become the *fresh* tensors (Concat input 1),
        // while the past KV (ReadValue or its Convert) is appended as extra inputs.
        auto args = sdp_node->input_values();
        args[1] = concat_k_node->input_value(1);
        args[2] = concat_v_node->input_value(1);
        args.push_back(read_cvt_k_node ? read_cvt_k_node->output(0) : past_k_node->output(0));
        args.push_back(read_cvt_v_node ? read_cvt_v_node->output(0) : past_v_node->output(0));
        ov::intel_cpu::ScaledDotProductAttentionStub::Config config;
        config.is_causal = sdp_node->get_causal();
        config.fuse_concat = true;
        auto old_node = sdp_node;
        auto new_node = std::make_shared<ov::intel_cpu::ScaledDotProductAttentionStub>(args, config);
        new_node->set_friendly_name(old_node->get_friendly_name());
        ov::replace_node(old_node, {new_node->output(0)});
        // Output 1/2 of the stub carry the concatenated K/V; feed them into the
        // Assign sinks (through the original Convert, if one was present).
        if (assign_cvt_k_node)
            assign_cvt_k_node->set_arguments({new_node->output(1)});
        else
            assign_k_node->set_arguments({new_node->output(1)});
        if (assign_cvt_v_node)
            assign_cvt_v_node->set_arguments({new_node->output(2)});
        else
            assign_v_node->set_arguments({new_node->output(2)});
        return true;
    };
    auto m = std::make_shared<ov::pass::pattern::Matcher>(sdp, matcher_name);
    this->register_matcher(m, callback);
}
} // namespace intel_cpu
} // namespace ov

View File

@ -0,0 +1,18 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ngraph/pass/graph_rewrite.hpp>
namespace ov {
namespace intel_cpu {
/// \brief Fuses the stateful KV-cache pattern (ReadValue -> Concat -> SDPA, with
///        Assign writing the concatenated KV back) into a single
///        ScaledDotProductAttentionStub node.
class StatefulSDPFusion : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("StatefulSDPFusion", "0");
    StatefulSDPFusion();
};
} // namespace intel_cpu
} // namespace ov

View File

@ -112,6 +112,7 @@
#include "transformations/cpu_opset/common/pass/move_eltwise_up_data_movement.hpp"
#include "transformations/cpu_opset/common/pass/swap_convert_transpose.hpp"
#include "transformations/cpu_opset/common/pass/rope_fusion.hpp"
#include "transformations/cpu_opset/common/pass/stateful_sdp_fusion.hpp"
// Snippets
#include "snippets/pass/tokenization.hpp"
@ -660,6 +661,7 @@ void Transformations::PostLpt() {
CPU_REGISTER_PASS_X64(postLPTPassManager, EliminateStridedSlice);
CPU_REGISTER_PASS_X64(postLPTPassManager, RoPEFusion);
CPU_REGISTER_PASS_X64(postLPTPassManager, StatefulSDPFusion);
postLPTPassManager.run_passes(model);
}

View File

@ -169,10 +169,19 @@ struct PlainTensor {
}
void reset(MemoryPtr mem) {
assert_dt<DT>(mem->getDesc().getPrecision());
const auto& mem_desc = mem->getDesc();
assert_dt<DT>(mem_desc.getPrecision());
const auto* desc_ptr = mem_desc.as<BlockedMemoryDesc>();
// not support block layout
OPENVINO_ASSERT(desc_ptr && desc_ptr->getOrder().size() == mem->getStaticDims().size());
m_mem = mem;
VectorDims strides(desc_ptr->getStrides().size());
const auto& orders = desc_ptr->getOrder();
for (size_t i = 0; i < orders.size(); i++) {
strides[orders[i]] = desc_ptr->getStrides()[i];
}
// this reshape_to() can do reshape w/o additional cost
resize(mem->getStaticDims(), reinterpret_cast<DT*>(mem->getData()));
resize(mem->getStaticDims(), reinterpret_cast<DT*>(mem->getData()), &strides);
}
ov::element::Type get_precision(void) {
@ -327,14 +336,14 @@ struct PlainTensor {
return new_tensor_view;
}
void resize(const VectorDims& new_dims, DT* data = nullptr) {
void resize(const VectorDims& new_dims, DT* data = nullptr, const VectorDims* strides = nullptr) {
// initialize strides for compact/dense tensor
m_rank = new_dims.size();
assert(m_rank <= PLAINTENSOR_RANK_MAX);
size_t stride = 1;
for (int i = m_rank - 1; i >= 0; i--) {
m_dims[i] = new_dims[i];
m_strides[i] = stride;
m_strides[i] = strides ? (*strides)[i] : stride;
stride *= new_dims[i];
}

View File

@ -72,7 +72,7 @@ protected:
for (auto&& shape : inputDynamicShapes)
params.push_back(std::make_shared<ov::op::v0::Parameter>(inType, shape));
auto shapeOf = std::make_shared<ngraph::opset3::ShapeOf>(params[0], ngraph::element::i32);
auto shapeOf = std::make_shared<ngraph::opset3::ShapeOf>(params.front(), ngraph::element::i32);
function = makeNgraphFunction(netPrecision, params, shapeOf, "ShapeOf");
}
@ -85,29 +85,6 @@ TEST_P(ShapeOfLayerCPUTest, CompareWithRefs) {
namespace {
/* CPU PARAMS */
std::vector<CPUSpecificParams> getCpuInfoForDimsCount(const size_t dimsCount = 3) {
std::vector<CPUSpecificParams> resCPUParams;
if (dimsCount == 5) {
resCPUParams.push_back(CPUSpecificParams{{nCdhw16c}, {x}, {}, {}});
resCPUParams.push_back(CPUSpecificParams{{nCdhw8c}, {x}, {}, {}});
resCPUParams.push_back(CPUSpecificParams{{ncdhw}, {x}, {}, {}});
resCPUParams.push_back(CPUSpecificParams{{ndhwc}, {x}, {}, {}});
} else if (dimsCount == 4) {
resCPUParams.push_back(CPUSpecificParams{{nChw16c}, {x}, {}, {}});
resCPUParams.push_back(CPUSpecificParams{{nChw8c}, {x}, {}, {}});
resCPUParams.push_back(CPUSpecificParams{{nchw}, {x}, {}, {}});
resCPUParams.push_back(CPUSpecificParams{{nhwc}, {x}, {}, {}});
} else {
resCPUParams.push_back(CPUSpecificParams{{nCw16c}, {x}, {}, {}});
resCPUParams.push_back(CPUSpecificParams{{nCw8c}, {x}, {}, {}});
resCPUParams.push_back(CPUSpecificParams{{abc}, {x}, {}, {}});
resCPUParams.push_back(CPUSpecificParams{{acb}, {x}, {}, {}});
}
return resCPUParams;
}
const std::vector<ElementType> netPrecisions = {
ElementType::f32,
ElementType::bf16,
@ -119,17 +96,9 @@ std::vector<ov::test::InputShape> inShapesDynamic3d = {
{
{-1, -1, -1},
{
{ 8, 5, 4 },
{ 8, 5, 3 },
{ 8, 5, 2 }
}
},
{
{-1, -1, -1},
{
{ 1, 2, 4 },
{ 1, 2, 3 },
{ 1, 2, 2 }
{ 8, 16, 4 },
{ 8, 16, 3 },
{ 8, 16, 2 }
}
}
};
@ -138,36 +107,20 @@ std::vector<ov::test::InputShape> inShapesDynamic4d = {
{
{-1, -1, -1, -1},
{
{ 8, 5, 3, 4 },
{ 8, 5, 3, 3 },
{ 8, 5, 3, 2 }
{ 8, 16, 3, 4 },
{ 8, 16, 3, 3 },
{ 8, 16, 3, 2 }
}
},
{
{-1, -1, -1, -1},
{
{ 1, 2, 3, 4 },
{ 1, 2, 3, 3 },
{ 1, 2, 3, 2 }
}
}
};
std::vector<ov::test::InputShape> inShapesDynamic5d = {
{
{ -1, -1, -1, -1, -1 },
{
{ 8, 5, 3, 2, 4 },
{ 8, 5, 3, 2, 3 },
{ 8, 5, 3, 2, 2 }
}
},
{
{-1, -1, -1, -1, -1},
{
{ 1, 2, 3, 4, 4 },
{ 1, 2, 3, 4, 3 },
{ 1, 2, 3, 4, 2 }
{ 8, 16, 3, 2, 4 },
{ 8, 16, 3, 2, 3 },
{ 8, 16, 3, 2, 2 }
}
}
};
@ -175,19 +128,19 @@ const auto params5dDynamic = ::testing::Combine(
::testing::Combine(
::testing::ValuesIn(inShapesDynamic5d),
::testing::ValuesIn(netPrecisions)),
::testing::ValuesIn(getCpuInfoForDimsCount(5)));
::testing::Values(emptyCPUSpec));
const auto params4dDynamic = ::testing::Combine(
::testing::Combine(
::testing::ValuesIn(inShapesDynamic4d),
::testing::ValuesIn(netPrecisions)),
::testing::ValuesIn(getCpuInfoForDimsCount(4)));
::testing::Values(emptyCPUSpec));
const auto params3dDynamic = ::testing::Combine(
::testing::Combine(
::testing::ValuesIn(inShapesDynamic3d),
::testing::ValuesIn(netPrecisions)),
::testing::ValuesIn(getCpuInfoForDimsCount(3)));
::testing::Values(emptyCPUSpec));
// We don't check static case, because of constant folding
INSTANTIATE_TEST_SUITE_P(smoke_ShapeOf3dDynamicLayoutTest, ShapeOfLayerCPUTest,

View File

@ -0,0 +1,228 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <openvino/opsets/opset13.hpp>
#include <transformations/op_conversions/scaled_dot_product_attention_decomposition.hpp>
#include "ov_models/builders.hpp"
#include "ov_models/utils/ov_helpers.hpp"
#include "shared_test_classes/base/layer_test_utils.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "test_utils/cpu_test_utils.hpp"
#include "common_test_utils/include/common_test_utils/ov_tensor_utils.hpp"
using namespace ov::test;
using namespace ngraph;
using namespace CPUTestUtils;
using namespace InferenceEngine;
namespace SubgraphTestsDefinitions {
// Test parameters: input precision, q/k/v + past-kv input shapes, and whether the
// graph additionally exposes ShapeOf(past-kv) results.
using ConcatSDPTestParams = std::tuple<ElementType,
                                       std::vector<InputShape>,
                                       bool // has ShapeOf
                                       >;
// Subgraph:
/* Parameter
* |
* Parameter ReadValue | ReadValue Parameter
* \ / | \ /
* \ / | \ /
* Concat | Concat
* / \ | / \
* / \ | / \
* / \ | / \
* Assign ScaledDotProductAttention Assign
* |
* Add
* |
* Result
*/
class ConcatSDPTest : public testing::WithParamInterface<ConcatSDPTestParams>, virtual public ov::test::SubgraphBaseTest, public CPUTestsBase {
public:
    // Encodes the parameter set (shapes, precision, ShapeOf presence) into the test name.
    static std::string getTestCaseName(const testing::TestParamInfo<ConcatSDPTestParams>& obj) {
        ElementType inType;
        std::vector<InputShape> inputShapes;
        bool hasShapeof;
        std::tie(inType, inputShapes, hasShapeof) = obj.param;
        std::ostringstream result;
        result << "IS=";
        for (const auto& shape : inputShapes) {
            result << ov::test::utils::partialShape2str({shape.first}) << "_";
        }
        result << "TS=";
        for (const auto& shape : inputShapes) {
            result << "(";
            if (!shape.second.empty()) {
                for (const auto& itr : shape.second) {
                    result << ov::test::utils::vec2str(itr);
                }
            }
            result << ")_";
        }
        result << "Prc=" << inType << "_";
        result << "HasShapeOf=" << hasShapeof;
        return result.str();
    }
    // Builds the stateful subgraph: ReadValue(pastk/pastv) -> Concat with fresh k/v
    // -> SDPA -> Add, with Assign sinks writing the concatenated KV back to the
    // state variables (optionally plus ShapeOf results on the ReadValue outputs).
    void SetUp() override {
        ElementType inType;
        std::vector<InputShape> inputShapes;
        bool hasShapeOf;
        std::tie(inType, inputShapes, hasShapeOf) = this->GetParam();
        targetDevice = ov::test::utils::DEVICE_CPU;
        rel_threshold = 1e-4f;
        if (inType == ElementType::bf16) {
            // bf16 execution is less accurate; relax the comparison threshold.
            configuration.insert({"ENFORCE_BF16", "YES"});
            rel_threshold = 0.01f;
        }
        init_input_shapes(inputShapes);
        ov::ParameterVector inputParams;
        // q,k,v
        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
        inputParams[0]->set_friendly_name("q");
        inputParams[1]->set_friendly_name("k");
        inputParams[2]->set_friendly_name("v");
        // pastkv init_cost
        // NOTE: both pastk and pastv states are initialized from the same
        // parameter (inputParams[3]).
        inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
        auto var_k = std::make_shared<ov::op::util::Variable>(
            ov::op::util::VariableInfo{inputDynamicShapes[1], inType, "pastk"});
        auto pastk = std::make_shared<ov::op::v6::ReadValue>(inputParams[3], var_k);
        pastk->set_friendly_name("pastk_r");
        auto var_v = std::make_shared<ov::op::util::Variable>(
            ov::op::util::VariableInfo{inputDynamicShapes[1], inType, "pastv"});
        auto pastv = std::make_shared<ov::op::v6::ReadValue>(inputParams[3], var_v);
        pastv->set_friendly_name("pastv_r");
        std::shared_ptr<Node> pastk_shapeof, pastv_shapeof;
        if (hasShapeOf) {
            pastk_shapeof = std::make_shared<ov::op::v0::ShapeOf>(pastk);
            pastv_shapeof = std::make_shared<ov::op::v0::ShapeOf>(pastv);
        }
        // Concatenate past and fresh KV along the sequence axis (axis 2).
        auto concatK = std::make_shared<ov::op::v0::Concat>(OutputVector{pastk, inputParams[1]}, 2);
        auto concatV = std::make_shared<ov::op::v0::Concat>(OutputVector{pastv, inputParams[2]}, 2);
        auto sdp = std::make_shared<ov::opset13::ScaledDotProductAttention>(inputParams[0], concatK, concatV, false);
        sdp->set_friendly_name("mha");
        auto add = std::make_shared<op::v1::Add>(sdp, op::v0::Constant::create(inType, {1}, {1.0f}));
        auto pastk_assign = std::make_shared<op::v6::Assign>(concatK, var_k);
        auto pastv_assign = std::make_shared<op::v6::Assign>(concatV, var_v);
        pastk_assign->set_friendly_name("pastk_w");
        pastv_assign->set_friendly_name("pastv_w");
        ResultVector results{std::make_shared<ov::op::v0::Result>(add)};
        if (hasShapeOf) {
            results.push_back(std::make_shared<ov::op::v0::Result>(pastk_shapeof));
            results.push_back(std::make_shared<ov::op::v0::Result>(pastv_shapeof));
        }
        SinkVector sinks{pastk_assign, pastv_assign};
        function = std::make_shared<Function>(results, sinks, inputParams, "ConcatSDP");
        targetDevice = ov::test::utils::DEVICE_CPU;
        // Reference: same graph with SDPA decomposed into primitive ops.
        functionRefs = function->clone();
        pass::Manager manager;
        // decompose ScaledDotProductAttention
        manager.register_pass<ov::pass::ScaledDotProductAttentionDecomposition>();
        manager.run_passes(functionRefs);
    }
    // Maps the two target shapes onto the four parameters:
    // q/k/v share targetInputStaticShapes[0], past-kv uses [1].
    void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
        std::vector<ov::Shape> shapes(4);
        shapes[0] = targetInputStaticShapes[0];
        shapes[1] = targetInputStaticShapes[0];
        shapes[2] = targetInputStaticShapes[0];
        shapes[3] = targetInputStaticShapes[1];
        SubgraphBaseTest::generate_inputs(shapes);
    }
    // Fills n elements starting at `value`, advancing by `stride` each step.
    template<typename IT, typename T>
    void strided_iota(IT first, size_t n, T value, T stride) {
        for (size_t i = 0; i < n; i++) {
            *first++ = value;
            value += stride;
        }
    }
    // Creates deterministic, iteration-dependent inputs so every shape step
    // produces distinct but reproducible data.
    void generate(int idx, const std::vector<ov::Shape>& targetInputStaticShapes) {
        inputs.clear();
        auto create_input = [this] (std::shared_ptr<op::v0::Parameter> param, ov::Shape shape, float val) {
            if (param->get_element_type() == element::f32) {
                ov::Tensor t{ov::element::f32, shape};
                strided_iota(static_cast<float*>(t.data()), t.get_size(), val, 0.1f);
                inputs.insert({param, t});
            } else {
                ov::Tensor t{ov::element::bf16, shape};
                strided_iota(static_cast<ov::bfloat16*>(t.data()), t.get_size(), val, 0.1f);
                inputs.insert({param, t});
            }
        };
        // q, k, v
        create_input(function->get_parameters()[0], targetInputStaticShapes[0], idx + 1.0f);
        create_input(function->get_parameters()[1], targetInputStaticShapes[0], idx + 2.0f);
        create_input(function->get_parameters()[2], targetInputStaticShapes[0], idx + 3.0f);
        create_input(function->get_parameters()[3], targetInputStaticShapes[1], idx + 4.0f);
    }
    // Compiles the current `function` and creates an infer request.
    void prepare() {
        compile_model();
        inferRequest = compiledModel.create_infer_request();
        ASSERT_TRUE(inferRequest);
    }
    // Resets all variable states and drops the infer request so a subsequent
    // run starts from a clean state.
    void reset() {
        for (auto&& state : inferRequest.query_state()) {
            state.reset();
        }
        inferRequest = ov::InferRequest();
    }
    // Runs inference over every target shape step, returning deep copies of
    // output 0 so results survive the next iteration's reallocation.
    std::vector<ov::Tensor> run_test(std::shared_ptr<ov::Model> model) {
        function = model;
        prepare();
        std::vector<ov::Tensor> outputs;
        int idx = 0;
        for (auto&& shapes : targetStaticShapes) {
            generate(idx++, shapes);
            for (const auto& input : inputs) {
                inferRequest.set_tensor(input.first, input.second);
            }
            inferRequest.infer();
            auto outputTensor = inferRequest.get_output_tensor(0);
            ov::Tensor copy{outputTensor.get_element_type(), outputTensor.get_shape()};
            outputTensor.copy_to(copy);
            outputs.push_back(copy);
        }
        reset();
        return outputs;
    }
};
TEST_P(ConcatSDPTest, CompareWithRefs) {
    // Fused model: after StatefulSDPFusion exactly one SDPA node must remain and
    // the Concat/Reorder nodes must be fused away.
    auto actualOutputs = run_test(function);
    CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 1);
    CheckNumberOfNodesWithType(compiledModel, "Concatenation", 0);
    CheckNumberOfNodesWithType(compiledModel, "Reorder", 0);
    // Reference model: SDPA is decomposed, so no SDPA nodes may remain.
    auto expectedOutputs = run_test(functionRefs);
    CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 0);
    for (size_t i = 0; i < actualOutputs.size(); i++) {
        ov::test::utils::compare(expectedOutputs[i], actualOutputs[i], abs_threshold, rel_threshold);
    }
}
namespace {
// One shape pair per entry: new-token q/k/v shape followed by the past-KV shape.
// Sequence lengths vary per inference step to exercise the growing KV cache.
const std::vector<std::vector<InputShape>> inputShapes = {
    // dynamic batch
    {
        // B, H, L1, S
        {{1, 8, -1, 64}, {{1, 8, 10, 64}, {1, 8, 1, 64}, {1, 8, 1, 64}, {1, 8, 20, 64}, {1, 8, 1, 64}}},
        // B, H, L0, S
        {{1, 8, -1, 64}, {{1, 8, 0, 64}, {1, 8, 10, 64}, {1, 8, 11, 64}, {1, 8, 12, 64}, {1, 8, 32, 64}}},
    },
};
INSTANTIATE_TEST_SUITE_P(smoke_ConcatSDPTest,
                         ConcatSDPTest,
                         ::testing::Combine(::testing::Values(ElementType::f32),
                                            ::testing::ValuesIn(inputShapes),
                                            ::testing::Values(true, false)),
                         ConcatSDPTest::getTestCaseName);
} // namespace
} // namespace SubgraphTestsDefinitions

View File

@ -0,0 +1,201 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "test_utils/cpu_test_utils.hpp"
#include "ov_models/builders.hpp"
#include "ov_models/utils/ov_helpers.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
using namespace InferenceEngine;
using namespace CPUTestUtils;
using InputShape = ov::test::InputShape;
using ElementType = ov::element::Type_t;
namespace SubgraphTestsDefinitions {
// ┌────────┐
// │ Param │
// └───┬────┘
// │
// │
// │
// ┌───┴────┐ To simulate different layouts
// │ Eltwise│ ◄─────────────────────────────
// └───┬────┘
// │ No Reorders are expected
// │ ◄───────────────────────────
// │
// ┌───┴────┐
// │ShapeOf │
// └───┬────┘
// │
// │
// │
// ┌───┴────┐
// │ Output │
// └────────┘
// Basic parameters: input shape and network precision.
typedef std::tuple<
        InputShape,
        ElementType // Net precision
> ShapeOfAnyLayoutParams;
// Full parameter set: basic parameters plus the CPU-specific layout configuration.
typedef std::tuple<
        ShapeOfAnyLayoutParams,
        CPUSpecificParams
> ShapeOfAnyLayoutCPUTestParamsSet;
class ShapeOfAnyLayoutCPUTest : public testing::WithParamInterface<ShapeOfAnyLayoutCPUTestParamsSet>,
                                virtual public ov::test::SubgraphBaseTest, public CPUTestsBase {
public:
    // Encodes precision, CPU layout params and shapes into the test name.
    static std::string getTestCaseName(testing::TestParamInfo<ShapeOfAnyLayoutCPUTestParamsSet> obj) {
        SubgraphTestsDefinitions::ShapeOfAnyLayoutParams basicParamsSet;
        CPUSpecificParams cpuParams;
        std::tie(basicParamsSet, cpuParams) = obj.param;
        ElementType netPr;
        InputShape inputShape;
        std::tie(inputShape, netPr) = basicParamsSet;
        std::ostringstream result;
        result << "ShapeOfTest_";
        result << std::to_string(obj.index) << "_";
        result << "Prec=" << netPr << "_";
        result << CPUTestsBase::getTestCaseName(cpuParams) << "_";
        result << "IS=";
        result << ov::test::utils::partialShape2str({inputShape.first}) << "_";
        result << "TS=(";
        for (const auto& shape : inputShape.second) {
            result << ov::test::utils::vec2str(shape) << "_";
        }
        result << ")";
        return result.str();
    }
protected:
    // Builds Param -> Eltwise (layout-enforcing stub) -> ShapeOf; ShapeOf is
    // expected to accept the Eltwise output layout without inserting a Reorder.
    void SetUp() override {
        targetDevice = ov::test::utils::DEVICE_CPU;
        ShapeOfAnyLayoutParams basicParamsSet;
        CPUSpecificParams cpuParams;
        std::tie(basicParamsSet, cpuParams) = this->GetParam();
        std::vector<cpu_memory_format_t> eltwiseInFmts, eltwiseOutFmts;
        std::tie(eltwiseInFmts, eltwiseOutFmts, priority, selectedType) = cpuParams;
        auto netPrecision = ElementType::undefined;
        InputShape inputShape;
        std::tie(inputShape, netPrecision) = basicParamsSet;
        init_input_shapes({inputShape});
        inType = ov::element::Type(netPrecision);
        outType = ElementType::i32;
        selectedType = makeSelectedTypeStr("ref", inType);
        ov::ParameterVector params;
        for (auto&& shape : inputDynamicShapes)
            params.push_back(std::make_shared<ov::op::v0::Parameter>(inType, shape));
        //make a stub eltwise node to enforce layout, since ShapeOf just mimic any input layout
        auto eltwise = ngraph::builder::makeActivation(params[0], inType, ov::test::utils::ActivationTypes::Relu);
        eltwise->get_rt_info() = makeCPUInfo(eltwiseInFmts, eltwiseOutFmts, {});
        auto shapeOf = std::make_shared<ngraph::opset3::ShapeOf>(eltwise, ngraph::element::i32);
        function = makeNgraphFunction(netPrecision, params, shapeOf, "ShapeOf");
    }
};
TEST_P(ShapeOfAnyLayoutCPUTest, CompareWithRefs) {
    run();
    CheckPluginRelatedResults(compiledModel, "ShapeOf");
    // Exactly one Reorder is expected (for the Eltwise layout), none for ShapeOf itself.
    CheckNumberOfNodesWithType(compiledModel, "Reorder", 1);
}
namespace {
/* CPU PARAMS */
// Returns the CPU layout configurations to test for a given tensor rank:
// one channel-blocked layout (8c or 16c depending on AVX512 support) and
// one channels-last / permuted layout.
std::vector<CPUSpecificParams> getCpuInfoForDimsCount(const size_t dimsCount = 3) {
    const bool useAvx512 = with_cpu_x86_avx512f();
    // Same format is used for both input and output descriptors.
    auto makeParams = [](cpu_memory_format_t fmt) {
        return CPUSpecificParams{{fmt}, {fmt}, {}, {}};
    };
    std::vector<CPUSpecificParams> layouts;
    switch (dimsCount) {
    case 5:
        layouts.push_back(makeParams(useAvx512 ? nCdhw16c : nCdhw8c));
        layouts.push_back(makeParams(ndhwc));
        break;
    case 4:
        layouts.push_back(makeParams(useAvx512 ? nChw16c : nChw8c));
        layouts.push_back(makeParams(nhwc));
        break;
    default:
        layouts.push_back(makeParams(useAvx512 ? nCw16c : nCw8c));
        layouts.push_back(makeParams(acb));
        break;
    }
    return layouts;
}
// Only f32 is exercised; the layout check does not depend on precision.
const std::vector<ElementType> netPrecisions = {
        ElementType::f32
};
// Channel dim is fixed at 16 so both 8c and 16c blocked layouts divide it evenly.
std::vector<ov::test::InputShape> inShapesDynamic3d = {
        {
            {-1, 16, -1},
            {
                { 8, 16, 4 },
                { 8, 16, 3 },
                { 8, 16, 2 }
            }
        }
};
std::vector<ov::test::InputShape> inShapesDynamic4d = {
        {
            {-1, 16, -1, -1},
            {
                { 8, 16, 3, 4 },
                { 8, 16, 3, 3 },
                { 8, 16, 3, 2 }
            }
        },
};
std::vector<ov::test::InputShape> inShapesDynamic5d = {
        {
            { -1, 16, -1, -1, -1 },
            {
                { 8, 16, 3, 2, 4 },
                { 8, 16, 3, 2, 3 },
                { 8, 16, 3, 2, 2 }
            }
        }
};
// Combine shapes/precisions with the per-rank CPU layout configurations.
const auto params5dDynamic = ::testing::Combine(
        ::testing::Combine(
                ::testing::ValuesIn(inShapesDynamic5d),
                ::testing::ValuesIn(netPrecisions)),
        ::testing::ValuesIn(getCpuInfoForDimsCount(5)));
const auto params4dDynamic = ::testing::Combine(
        ::testing::Combine(
                ::testing::ValuesIn(inShapesDynamic4d),
                ::testing::ValuesIn(netPrecisions)),
        ::testing::ValuesIn(getCpuInfoForDimsCount(4)));
const auto params3dDynamic = ::testing::Combine(
        ::testing::Combine(
                ::testing::ValuesIn(inShapesDynamic3d),
                ::testing::ValuesIn(netPrecisions)),
        ::testing::ValuesIn(getCpuInfoForDimsCount(3)));
// We don't check static case, because of constant folding
INSTANTIATE_TEST_SUITE_P(smoke_ShapeOf3dAnyLayoutTest, ShapeOfAnyLayoutCPUTest,
                         params3dDynamic, ShapeOfAnyLayoutCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_ShapeOf4dAnyLayoutTest, ShapeOfAnyLayoutCPUTest,
                         params4dDynamic, ShapeOfAnyLayoutCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_ShapeOf5dAnyLayoutTest, ShapeOfAnyLayoutCPUTest,
                         params5dDynamic, ShapeOfAnyLayoutCPUTest::getTestCaseName);
} // namespace
} // namespace SubgraphTestsDefinitions

View File

@ -0,0 +1,161 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <common/blocked_desc_creator.h>
#include <cpu_types.h>
#include <edge.h>
#include <gtest/gtest.h>
#include <ie_common.h>
#include <memory_desc/cpu_memory_desc_utils.h>
#include <memory_desc/dnnl_memory_desc.h>
#include <node.h>
#include <nodes/reorder.h>
#include <common/memory_desc_wrapper.hpp>
#include <dnnl.hpp>
#include <utility>
#include "common_test_utils/common_utils.hpp"
#include "cache/multi_cache.h"
#include "ov_models/builders.hpp"
#include "nodes/scaled_attn.h"
#include "nodes/input.h"
#include "graph.h"
#include "cpu_tensor.h"
using namespace ov::intel_cpu;
TEST(ScaledAttnGraphTest, smoke_Check_Scaled_Concat_Noplace) {
    // Builds a minimal CPU-plugin graph around a fused SDPA node whose q/k/v and
    // past-kv inputs are constants, plus outputs for the concatenated K/V caches.
    auto build_graph = [](const ov::Shape& shape, float* qkv_val, float* past_kv_val) {
        auto qkv = ov::op::v0::Constant::create(ov::element::f32, shape, qkv_val);
        qkv->set_friendly_name("qkv_const");
        auto pastkv = ov::op::v0::Constant::create(ov::element::f32, shape, past_kv_val);
        pastkv->set_friendly_name("pastkv_const");
        // only need a dynamic parameter but its value will not be used
        auto attn = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{-1});
        attn->set_friendly_name("attn");
        ov::intel_cpu::ScaledDotProductAttentionStub::Config config;
        config.fuse_concat = true;
        config.is_causal = true;
        auto sdpa = std::make_shared<ov::intel_cpu::ScaledDotProductAttentionStub>(ov::OutputVector{qkv, qkv, qkv, attn, pastkv, pastkv}, config);
        auto out_qkv = std::make_shared<ov::op::v0::Result>(sdpa->output(0));
        out_qkv->set_friendly_name("qkv");
        auto out_pastk = std::make_shared<ov::op::v0::Result>(sdpa->output(1));
        out_pastk->set_friendly_name("pastk");
        auto out_pastv = std::make_shared<ov::op::v0::Result>(sdpa->output(2));
        out_pastv->set_friendly_name("pastv");
        std::unordered_set<NodePtr> nodes_set;
        std::vector<EdgePtr> graph_edges;
        // Registers an edge and records both endpoints in the node set.
        auto add_edge = [&](const NodePtr& parent, const NodePtr& child, size_t parentPort, size_t childPort) -> void {
            auto edge = std::make_shared<Edge>(parent, child, parentPort, childPort);
            child->addEdge(edge);
            graph_edges.push_back(edge);
            nodes_set.insert(parent);
            nodes_set.insert(child);
        };
        //create graph context
        // rtCacheCapacity = 0 disables the runtime cache for this test graph.
        Config conf;
        conf.rtCacheCapacity = 0;
        auto context = std::make_shared<GraphContext>(conf, nullptr, nullptr, false);
        auto qkv_node = std::make_shared<node::Input>(qkv, context);
        auto pastkv_node = std::make_shared<node::Input>(pastkv, context);
        auto attn_node = std::make_shared<node::Input>(attn, context);
        auto sdpa_node = std::make_shared<node::ScaledDotProductAttention>(sdpa, context);
        auto out_qkv_node = std::make_shared<node::Input>(out_qkv, context);
        auto out_pastk_node = std::make_shared<node::Input>(out_pastk, context);
        auto out_pastv_node = std::make_shared<node::Input>(out_pastv, context);
        // Wire q/k/v from the same constant, then attn mask and both KV caches.
        add_edge(qkv_node, sdpa_node, 0, 0);
        add_edge(qkv_node, sdpa_node, 0, 1);
        add_edge(qkv_node, sdpa_node, 0, 2);
        add_edge(attn_node, sdpa_node, 0, 3);
        add_edge(pastkv_node, sdpa_node, 0, 4);
        add_edge(pastkv_node, sdpa_node, 0, 5);
        add_edge(sdpa_node, out_qkv_node, 0, 0);
        add_edge(sdpa_node, out_pastk_node, 1, 0);
        add_edge(sdpa_node, out_pastv_node, 2, 0);
        std::vector<NodePtr> graph_nodes(nodes_set.begin(), nodes_set.end());
        Graph graph;
        graph.CreateGraph(graph_nodes, graph_edges, context, "test_graph");
        return graph;
    };
    // Assigns a shape to the dynamic input, propagates shapes, and runs inference.
    auto run_graph = [] (Graph& graph) {
        // NOTE(review): assumes the first entry of the input map is the dynamic
        // `attn` parameter — confirm map ordering if inputs change.
        graph.GetInputNodesMap().begin()->second->redefineOutputMemory(0, {1});
        for (auto& node : graph.GetNodes()) {
            if (node->isDynamicNode()) {
                node->updateShapes();
                node->updateDynamicParams();
            }
        }
        graph.Infer();
    };
    // Compares selected output node memories against expected values and shapes.
    auto check_graph = [] (Graph& graph, std::map<std::string, std::pair<float*, ov::Shape>>& expected) {
        auto& outputNodesMap = graph.GetOutputNodesMap();
        auto is_same = [] (float a, float b) {
            return std::abs(a - b) < 0.0001f;
        };
        for (auto &outputMap : outputNodesMap) {
            auto name = outputMap.first;
            if (expected.count(name) == 0) {
                continue;
            }
            auto node = outputMap.second;
            auto parentEdge = node->getParentEdgeAt(0);
            const auto& memory = parentEdge->getMemoryPtr();
            auto size = memory->getSize() / sizeof(float);
            auto p = reinterpret_cast<float*>(memory->getData());
            for (size_t i = 0; i < size; i++) {
                ASSERT_EQ(is_same(p[i], expected.at(name).first[i]), true);
            }
            ASSERT_EQ(memory->getShape(), ov::intel_cpu::Shape(expected.at(name).second));
        }
    };
    // Returns the first node of the given type, or nullptr.
    auto find_node_type = [](const Graph& graph, Type type) -> NodePtr {
        auto&& nodes = graph.GetNodes();
        auto itr =
            std::find_if(nodes.begin(), nodes.end(), [=](const NodePtr& node){ return type == node->getType(); });
        if (itr == nodes.end()) {
            return nullptr;
        }
        return (*itr);
    };
    // Fills n floats starting at `value`, advancing by `stride`.
    auto strided_iota = [] (float* first, size_t n, float value, float stride) {
        for (size_t i = 0; i < n; i++) {
            *first++ = value;
            value += stride;
        }
    };
    ov::Shape shape{1, 1, 8, 8};
    const size_t elements_count = std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<size_t>());
    // One contiguous buffer: first half is past-kv, second half is qkv, so the
    // expected concat result is simply the whole buffer.
    std::vector<float> val(elements_count * 2);
    strided_iota(val.data(), val.size(), -10.0f, 0.1f);
    auto graph = build_graph(shape, val.data() + elements_count, val.data());
    run_graph(graph);
    // if no inplace, the pastk and pastv will concat, check shape and value
    ov::Shape expectedShape(shape);
    expectedShape[2] *= 2;
    std::map<std::string, std::pair<float*, ov::Shape>> expected{
        {"pastk", std::make_pair(val.data(), expectedShape)},
        {"pastv", std::make_pair(val.data(), expectedShape)}};
    check_graph(graph, expected);
    // KV outputs must not be in-place with the SDPA inputs.
    auto spd = find_node_type(graph, Type::ScaledDotProductAttention)->getSelectedPrimitiveDescriptor();
    ASSERT_EQ(spd->getConfig().outConfs[1].inPlace(), -1);
    ASSERT_EQ(spd->getConfig().outConfs[2].inPlace(), -1);
}

View File

@ -0,0 +1,105 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <openvino/opsets/opset13.hpp>
#include <transformations/cpu_opset/common/op/sdp.hpp>
#include <transformations/cpu_opset/common/pass/stateful_sdp_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include <openvino/pass/manager.hpp>
#include <ie_core.hpp>
#include "common_test_utils/ov_test_utils.hpp"
using namespace testing;
using namespace ov::intel_cpu;
using namespace ov;
// Builds the stateful Concat+SDPA test model.
//  isRef      - build the already-fused reference (ScaledDotProductAttentionStub)
//               instead of the ReadValue->Concat->SDPA pattern.
//  hasConvert - insert Convert nodes after ReadValue and before Assign, to test
//               the fusion across precision conversions.
// NOTE(review): the `init` parameter is registered in the ParameterVector but not
// connected to any node — presumably kept for signature parity; verify intent.
static std::shared_ptr<ov::Model> makeSDPA(const ov::PartialShape& inputShape, bool isRef = false, bool hasConvert = false) {
    auto q = std::make_shared<ov::op::v0::Parameter>(element::f32, inputShape);
    auto k = std::make_shared<ov::op::v0::Parameter>(element::f32, inputShape);
    auto v = std::make_shared<ov::op::v0::Parameter>(element::f32, inputShape);
    auto init = std::make_shared<ov::op::v0::Parameter>(element::f32, inputShape);
    auto var_k = std::make_shared<ov::op::util::Variable>(
        ov::op::util::VariableInfo{inputShape, element::f32, "pastk"});
    std::shared_ptr<ov::Node> pastk = std::make_shared<ov::op::v6::ReadValue>(k, var_k);
    auto var_v = std::make_shared<ov::op::util::Variable>(
        ov::op::util::VariableInfo{inputShape, element::f32, "pastv"});
    std::shared_ptr<ov::Node> pastv = std::make_shared<ov::op::v6::ReadValue>(v, var_v);
    Output<ov::Node> concatK, concatV, sdp;
    if (hasConvert) {
        pastk = std::make_shared<ov::op::v0::Convert>(pastk, element::f32);
        pastv = std::make_shared<ov::op::v0::Convert>(pastv, element::f32);
    }
    if (isRef) {
        // Fused form: outputs 1/2 of the stub carry the concatenated K/V.
        ov::intel_cpu::ScaledDotProductAttentionStub::Config config;
        config.fuse_concat = true;
        auto new_node = std::make_shared<ov::intel_cpu::ScaledDotProductAttentionStub>(OutputVector{q, k, v, pastk, pastv}, config);
        sdp = new_node->output(0);
        concatK = new_node->output(1);
        concatV = new_node->output(2);
    } else {
        // Unfused form: explicit Concat along the sequence axis (axis 2).
        concatK = std::make_shared<ov::op::v0::Concat>(OutputVector{pastk, k}, 2);
        concatV = std::make_shared<ov::op::v0::Concat>(OutputVector{pastv, v}, 2);
        sdp = std::make_shared<ov::opset13::ScaledDotProductAttention>(q, concatK, concatV, false);
    }
    if (hasConvert) {
        concatK = std::make_shared<ov::op::v0::Convert>(concatK, element::f32);
        concatV = std::make_shared<ov::op::v0::Convert>(concatV, element::f32);
    }
    auto pastk_assign = std::make_shared<op::v6::Assign>(concatK, var_k);
    auto pastv_assign = std::make_shared<op::v6::Assign>(concatV, var_v);
    auto add = std::make_shared<op::v1::Add>(sdp, op::v0::Constant::create(element::f32, {1}, {1.0f}));
    ResultVector results{std::make_shared<ov::op::v0::Result>(add)};
    SinkVector sinks{pastk_assign, pastv_assign};
    return std::make_shared<Model>(results, sinks, ParameterVector{q, k, v, init}, "ConcatSDP");
}
TEST(TransformationTests, StateConcatSDPA) {
    // Build the stateful Concat+SDPA pattern and run the fusion pass on it.
    const ov::PartialShape shape{-1, 8, -1, 64};
    std::shared_ptr<ov::Model> model = makeSDPA(shape);
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::InitNodeInfo>();
    manager.register_pass<StatefulSDPFusion>();
    manager.run_passes(model);
    // The reference graph is constructed directly with the fused stub node.
    std::shared_ptr<ov::Model> reference = makeSDPA(shape, true);
    const auto cmp = compare_functions(model, reference);
    ASSERT_TRUE(cmp.first) << cmp.second;
}
TEST(TransformationTests, StateConcatSDPAWithConvert) {
    // Same as StateConcatSDPA, but with Convert nodes around the state
    // reads/writes to verify the fusion looks through precision conversions.
    const ov::PartialShape shape{-1, 8, -1, 64};
    std::shared_ptr<ov::Model> model = makeSDPA(shape, false, true);
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::InitNodeInfo>();
    manager.register_pass<StatefulSDPFusion>();
    manager.run_passes(model);
    // Reference graph: fused stub with the same Convert placement.
    std::shared_ptr<ov::Model> reference = makeSDPA(shape, true, true);
    const auto cmp = compare_functions(model, reference);
    ASSERT_TRUE(cmp.first) << cmp.second;
}