[CPU] Add optimized memory management for SDPA KV cache (#21242)
This commit is contained in:
parent
718b5a60bf
commit
405d97e4a5
@ -245,6 +245,49 @@ void MemoryMngrWithReuse::destroy(void *ptr) {
|
||||
dnnl::impl::free(ptr);
|
||||
}
|
||||
|
||||
void* MemoryMngrRealloc::getRawPtr() const noexcept {
|
||||
return m_data.get();
|
||||
}
|
||||
|
||||
void MemoryMngrRealloc::setExtBuff(void *ptr, size_t size) {
|
||||
m_useExternalStorage = true;
|
||||
m_memUpperBound = size;
|
||||
m_data = decltype(m_data)(ptr, release);
|
||||
}
|
||||
|
||||
bool MemoryMngrRealloc::resize(size_t size) {
|
||||
constexpr int cacheLineSize = 64;
|
||||
constexpr size_t growFactor = 2;
|
||||
bool sizeChanged = false;
|
||||
if (size > m_memUpperBound) {
|
||||
size *= growFactor;
|
||||
void *ptr = dnnl::impl::malloc(size, cacheLineSize);
|
||||
if (!ptr) {
|
||||
OPENVINO_THROW("Failed to allocate ", size, " bytes of memory");
|
||||
}
|
||||
|
||||
if (auto src = m_data.get()) {
|
||||
std::memcpy(ptr, src, m_memUpperBound);
|
||||
}
|
||||
|
||||
m_memUpperBound = size;
|
||||
m_useExternalStorage = false;
|
||||
m_data = decltype(m_data)(ptr, destroy);
|
||||
sizeChanged = true;
|
||||
}
|
||||
return sizeChanged;
|
||||
}
|
||||
|
||||
bool MemoryMngrRealloc::hasExtBuffer() const noexcept {
|
||||
return m_useExternalStorage;
|
||||
}
|
||||
|
||||
void MemoryMngrRealloc::release(void *ptr) {}
|
||||
|
||||
void MemoryMngrRealloc::destroy(void *ptr) {
|
||||
dnnl::impl::free(ptr);
|
||||
}
|
||||
|
||||
void* DnnlMemoryMngr::getRawPtr() const noexcept {
|
||||
return m_pMemMngr->getRawPtr();
|
||||
}
|
||||
|
@ -89,6 +89,23 @@ private:
|
||||
static void destroy(void *ptr);
|
||||
};
|
||||
|
||||
class MemoryMngrRealloc : public IMemoryMngr {
|
||||
public:
|
||||
MemoryMngrRealloc() : m_data(nullptr, release) {}
|
||||
void* getRawPtr() const noexcept override;
|
||||
void setExtBuff(void* ptr, size_t size) override;
|
||||
bool resize(size_t size) override;
|
||||
bool hasExtBuffer() const noexcept override;
|
||||
|
||||
private:
|
||||
bool m_useExternalStorage = false;
|
||||
size_t m_memUpperBound = 0ul;
|
||||
std::unique_ptr<void, void (*)(void *)> m_data;
|
||||
|
||||
static void release(void *ptr);
|
||||
static void destroy(void *ptr);
|
||||
};
|
||||
|
||||
class IMemoryMngrObserver : public IMemoryMngr {
|
||||
public:
|
||||
virtual void registerMemory(Memory* memPtr) = 0;
|
||||
|
@ -214,6 +214,7 @@ static const TypeToNameMap& get_type_to_name_tbl() {
|
||||
{ "Unique", Type::Unique},
|
||||
{ "Ngram", Type::Ngram},
|
||||
{ "ScaledDotProductAttention", Type::ScaledDotProductAttention},
|
||||
{ "ScaledDotProductAttentionStub", Type::ScaledDotProductAttention},
|
||||
{ "RoPE", Type::RoPE},
|
||||
};
|
||||
return type_to_name_tbl;
|
||||
|
@ -6,6 +6,7 @@
|
||||
#include "transformations/cpu_opset/common/op/fully_connected.hpp"
|
||||
#include "transformations/cpu_opset/common/op/leaky_relu.hpp"
|
||||
#include "transformations/cpu_opset/common/op/power_static.hpp"
|
||||
#include "transformations/cpu_opset/common/op/sdp.hpp"
|
||||
#include "transformations/cpu_opset/common/op/swish_cpu.hpp"
|
||||
#include "transformations/cpu_opset/common/op/ngram.hpp"
|
||||
#include "transformations/cpu_opset/x64/op/mha.hpp"
|
||||
@ -59,6 +60,7 @@ std::map<std::string, ngraph::OpSet> Extension::getOpSets() {
|
||||
NGRAPH_OP(NgramNode, ov::intel_cpu)
|
||||
NGRAPH_OP_X64(MHANode, ov::intel_cpu)
|
||||
NGRAPH_OP_X64(InteractionNode, ov::intel_cpu)
|
||||
NGRAPH_OP_X64(ScaledDotProductAttentionStub, ov::intel_cpu)
|
||||
#undef NGRAPH_OP
|
||||
|
||||
return opset;
|
||||
|
@ -1441,7 +1441,7 @@ void Graph::GetPerfData(std::vector<ov::ProfilingInfo>& perfMap) const {
|
||||
}
|
||||
}
|
||||
|
||||
void Graph::RemoveEdge(EdgePtr& edge) {
|
||||
void Graph::RemoveEdge(const EdgePtr& edge) {
|
||||
for (auto it = graphEdges.begin(); it != graphEdges.end(); it++) {
|
||||
if ((*it) == edge) {
|
||||
edge->drop();
|
||||
@ -1881,9 +1881,9 @@ void Graph::resolveInPlaceDirection(const NodePtr& node) const {
|
||||
void Graph::SearchInternalStateNodes() {
|
||||
for (auto&& node : graphNodes) {
|
||||
if (node->getType() == Type::MemoryInput) {
|
||||
auto cur_node = std::dynamic_pointer_cast<node::MemoryInput>(node);
|
||||
auto cur_node = std::dynamic_pointer_cast<node::MemoryStateNode>(node);
|
||||
if (!cur_node) {
|
||||
OPENVINO_THROW("Cannot cast ", node->getName(), " to MemoryInput");
|
||||
OPENVINO_THROW("Cannot cast ", node->getName(), " to MemoryStateNode");
|
||||
}
|
||||
internalStateNodes.insert({cur_node->getId(), cur_node});
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ namespace intel_cpu {
|
||||
|
||||
class SyncInferRequest;
|
||||
namespace node {
|
||||
class MemoryNode;
|
||||
class MemoryStateNode;
|
||||
} // namespace node
|
||||
|
||||
class Graph {
|
||||
@ -123,7 +123,7 @@ public:
|
||||
|
||||
void RemoveDroppedNodes();
|
||||
void RemoveDroppedEdges();
|
||||
void RemoveEdge(EdgePtr& edge);
|
||||
void RemoveEdge(const EdgePtr& edge);
|
||||
void DropNode(const NodePtr& node);
|
||||
void DropDWConvNode(const NodePtr& node);
|
||||
|
||||
@ -197,7 +197,7 @@ public:
|
||||
}
|
||||
|
||||
Status getStatus() const {return status;}
|
||||
const std::unordered_map<std::string, std::shared_ptr<node::MemoryNode>>&
|
||||
const std::unordered_map<std::string, std::shared_ptr<node::MemoryStateNode>>&
|
||||
getInternalStateNodes() const {
|
||||
return internalStateNodes;
|
||||
}
|
||||
@ -259,7 +259,7 @@ private:
|
||||
std::map<std::string, NodePtr> outputNodesMap;
|
||||
|
||||
std::unordered_map<std::string, ProxyMemoryMngrPtr> outputNodesMemMngrMap;
|
||||
std::unordered_map<std::string, std::shared_ptr<node::MemoryNode>> internalStateNodes;
|
||||
std::unordered_map<std::string, std::shared_ptr<node::MemoryStateNode>> internalStateNodes;
|
||||
|
||||
// these node pointers (from graphNodes) are to avoid regular checking for
|
||||
// constantness of nodes in Infer methods and calls of
|
||||
|
@ -22,6 +22,7 @@
|
||||
#include "nodes/reduce.h"
|
||||
#include "nodes/input.h"
|
||||
#include "nodes/rnn.h"
|
||||
#include "nodes/memory.hpp"
|
||||
#include "nodes/common/cpu_convert.h"
|
||||
|
||||
#include "onednn/dnnl.h"
|
||||
@ -182,6 +183,18 @@ void GraphOptimizer::ApplyCommonGraphOptimizations(Graph &graph) {
|
||||
RemoveSameConvert(graph);
|
||||
graph.RemoveDroppedNodes();
|
||||
|
||||
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "RemoveMemoryInputConvert");
|
||||
RemoveMemoryInputConvert(graph);
|
||||
graph.RemoveDroppedNodes();
|
||||
|
||||
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "RemoveConvertMemoryOutput");
|
||||
RemoveConvertMemoryOutput(graph);
|
||||
graph.RemoveDroppedNodes();
|
||||
|
||||
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "MatchSdpaKvCache");
|
||||
MatchSdpaKvCache(graph);
|
||||
graph.RemoveDroppedNodes();
|
||||
|
||||
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "RemoveDroppedEdges");
|
||||
graph.RemoveDroppedEdges();
|
||||
}
|
||||
@ -2710,5 +2723,153 @@ void GraphOptimizer::RemoveSameConvert(Graph& graph) {
|
||||
}
|
||||
}
|
||||
|
||||
void GraphOptimizer::RemoveMemoryInputConvert(Graph &graph) {
|
||||
auto& graphNodes = graph.GetNodes();
|
||||
|
||||
auto isSuitableNode = [](const NodePtr& node) {
|
||||
if (Type::Convert != node->getType()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto parent = node->getParentEdgeAt(0)->getParent();
|
||||
if (Type::MemoryInput != parent->getType()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < graphNodes.size(); i++) {
|
||||
auto node = graphNodes[i];
|
||||
|
||||
if (!isSuitableNode(node)) {
|
||||
continue;
|
||||
}
|
||||
graph.DropNode(node);
|
||||
}
|
||||
}
|
||||
|
||||
void GraphOptimizer::RemoveConvertMemoryOutput(Graph &graph) {
|
||||
auto& graphNodes = graph.GetNodes();
|
||||
|
||||
auto isSuitableNode = [](const NodePtr& node) {
|
||||
if (Type::Convert != node->getType()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto&& childEdges = node->getChildEdgesAtPort(0);
|
||||
for (auto&& edge : childEdges) {
|
||||
if (Type::MemoryOutput != edge->getChild()->getType()) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < graphNodes.size(); i++) {
|
||||
auto node = graphNodes[i];
|
||||
|
||||
if (!isSuitableNode(node)) {
|
||||
continue;
|
||||
}
|
||||
graph.DropNode(node);
|
||||
}
|
||||
}
|
||||
|
||||
void GraphOptimizer::MatchSdpaKvCache(Graph &graph) {
|
||||
auto& graphNodes = graph.GetNodes();
|
||||
|
||||
auto isSuitableMemInput = [](const NodePtr& node) -> bool {
|
||||
if (Type::MemoryInput != node->getType()) {
|
||||
return false;
|
||||
}
|
||||
NodePtr childSdpa = nullptr;
|
||||
auto&& childEdges = node->getChildEdgesAtPort(0);
|
||||
for (auto&& item : childEdges) {
|
||||
auto childNode = item->getChild();
|
||||
if (!one_of(childNode->getType(), Type::ScaledDotProductAttention, Type::ShapeOf)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (Type::ScaledDotProductAttention == childNode->getType()) {
|
||||
if (childSdpa && childSdpa != childNode) {
|
||||
//only one child SDPA supported
|
||||
return false;
|
||||
}
|
||||
childSdpa = childNode;
|
||||
}
|
||||
}
|
||||
|
||||
CPU_GRAPH_OPTIMIZER_SCOPE(MatchSdpaKvCache_isSuitableMemInput);
|
||||
|
||||
auto memInputNode = std::dynamic_pointer_cast<node::MemoryInputBase>(node);
|
||||
OPENVINO_ASSERT(memInputNode, "MemoryInput node ", node->getName(), " has unexpected dynamic type");
|
||||
auto& memOutputNode = memInputNode->getOutputNode();
|
||||
auto memOutputParent = memOutputNode.getParentEdgeAt(0)->getParent();
|
||||
if (memOutputParent != childSdpa) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
for (size_t i = 0; i < graphNodes.size(); i++) {
|
||||
auto node = graphNodes[i];
|
||||
if (!isSuitableMemInput(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
CPU_GRAPH_OPTIMIZER_SCOPE(MatchSdpaKvCache_Node);
|
||||
|
||||
// Node is already modified
|
||||
if (auto sdpaMemInput = std::dynamic_pointer_cast<node::MemoryInputSDPA>(node)) {
|
||||
continue;
|
||||
}
|
||||
|
||||
auto memInputNode = std::dynamic_pointer_cast<node::MemoryInputBase>(node);
|
||||
OPENVINO_ASSERT(memInputNode, "MemoryInput node ", node->getName(), " has unexpected dynamic type");
|
||||
|
||||
ov::optional<Shape> input_shape;
|
||||
ov::optional<ov::element::Type> input_prc;
|
||||
|
||||
if (!node->getParentEdges().empty()) {
|
||||
input_shape = ov::optional<Shape>(node->getInputShapeAtPort(0));
|
||||
input_prc = ov::optional<ov::element::Type>(node->getOriginalInputPrecisionAtPort(0));
|
||||
}
|
||||
|
||||
auto memInputSdpa = std::make_shared<MemoryInputSDPA>(
|
||||
memInputNode->getId(),
|
||||
memInputNode->getName(),
|
||||
memInputNode->getTypeStr(),
|
||||
memInputNode->getOutputShapeAtPort(0),
|
||||
memInputNode->getOriginalOutputPrecisionAtPort(0),
|
||||
graph.getGraphContext(),
|
||||
input_shape,
|
||||
input_prc);
|
||||
|
||||
if (!memInputNode->getParentEdges().empty()) {
|
||||
auto parentEdge = memInputNode->getParentEdgeAt(0);
|
||||
auto newEdge = std::make_shared<Edge>(parentEdge->getParent(), memInputSdpa, parentEdge->getInputNum(), 0);
|
||||
memInputSdpa->addEdge(newEdge);
|
||||
graph.GetEdges().push_back(newEdge);
|
||||
graph.RemoveEdge(parentEdge);
|
||||
}
|
||||
|
||||
for (auto&& edge : memInputNode->getChildEdgesAtPort(0)) {
|
||||
auto newEdge = std::make_shared<Edge>(memInputSdpa, edge->getChild(), 0, edge->getOutputNum());
|
||||
memInputSdpa->addEdge(newEdge);
|
||||
graph.GetEdges().push_back(newEdge);
|
||||
graph.RemoveEdge(edge);
|
||||
}
|
||||
|
||||
//link with memory output
|
||||
auto& memOutput = memInputNode->getOutputNode();
|
||||
memInputSdpa->registerOutputNode(&memOutput);
|
||||
|
||||
graph.GetNodes().push_back(memInputSdpa);
|
||||
graph.DropNode(memInputNode);
|
||||
}
|
||||
}
|
||||
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
||||
|
@ -49,6 +49,9 @@ private:
|
||||
void MergeTransposeAndReorder(Graph &graph);
|
||||
void reshapeRnnSeq(Graph &graph);
|
||||
void RemoveSameConvert(Graph &graph);
|
||||
void RemoveMemoryInputConvert(Graph &graph);
|
||||
void RemoveConvertMemoryOutput(Graph &graph);
|
||||
void MatchSdpaKvCache(Graph &graph);
|
||||
};
|
||||
|
||||
} // namespace intel_cpu
|
||||
|
@ -61,7 +61,7 @@ void SyncInferRequest::create_infer_request() {
|
||||
init_tensor(it.first);
|
||||
}
|
||||
|
||||
//create states according to the list of the MemoryNodes
|
||||
//create states according to the list of the MemoryStateNodes
|
||||
for (auto&& node : m_graph->getInternalStateNodes()) {
|
||||
m_memory_states.emplace_back(node.second->makeState());
|
||||
}
|
||||
|
@ -14,13 +14,81 @@ using namespace InferenceEngine;
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
|
||||
VariableStateDoubleBuffer::VariableStateDoubleBuffer(std::string name,
|
||||
const MemBuilder& mem_build,
|
||||
MemoryDescPtr external_desc,
|
||||
MemoryCPtr init_val) :
|
||||
IVariableState{name}, m_external_desc{external_desc} {
|
||||
reset_prime_mem(mem_build());
|
||||
reset_second_mem(mem_build());
|
||||
VariableStateBase::VariableStateBase(const std::string& name, const MemoryDescPtr& external_desc) :
|
||||
IVariableState{name} , m_external_desc{external_desc} {}
|
||||
|
||||
MemoryDescPtr VariableStateBase::to_static(const MemoryDescPtr& desc) {
|
||||
if (!desc->isDefined()) {
|
||||
auto&& current_dims = desc->getShape().getDims();
|
||||
VectorDims new_dims(current_dims.size());
|
||||
std::transform(current_dims.begin(), current_dims.end(), new_dims.begin(), [](Dim x) {
|
||||
return x == Shape::UNDEFINED_DIM ? 0 : x; });
|
||||
|
||||
return desc->cloneWithNewDims(new_dims, true);
|
||||
}
|
||||
return desc;
|
||||
}
|
||||
|
||||
const dnnl::engine& VariableStateBase::get_engine() {
|
||||
static const dnnl::engine eng(dnnl::engine::kind::cpu, 0);
|
||||
return eng;
|
||||
}
|
||||
|
||||
void VariableStateBase::set_state(const ov::SoPtr<ov::ITensor>& state) {
|
||||
m_state = state; // simply to extend the lifetime
|
||||
auto state_desc = MemoryDescUtils::generateCpuBlockedMemoryDesc(m_state);
|
||||
|
||||
const auto& shape = state_desc->getShape();
|
||||
|
||||
if (input_mem()->getShape() != shape) {
|
||||
auto new_desc = internal_desc()->cloneWithNewDims(shape.getStaticDims());
|
||||
input_mem()->redefineDesc(new_desc);
|
||||
}
|
||||
|
||||
auto src = m_state->data();
|
||||
|
||||
Memory mem(get_engine(), state_desc, src);
|
||||
input_mem()->load(mem);
|
||||
}
|
||||
|
||||
ov::SoPtr<ov::ITensor> VariableStateBase::get_state() const {
|
||||
const auto& current_dims = internal_state_mem()->getStaticDims();
|
||||
auto current_ext_desc = m_external_desc->cloneWithNewDims(current_dims);
|
||||
auto current_internal_desc = internal_state_mem()->getDescPtr();
|
||||
|
||||
if (current_ext_desc->isCompatible(*current_internal_desc)) {
|
||||
return std::make_shared<Tensor>(internal_state_mem());
|
||||
}
|
||||
|
||||
//test precision
|
||||
{
|
||||
auto internal_prc = current_internal_desc->getPrecision();
|
||||
auto tmp_desc = current_ext_desc->cloneWithNewPrecision(internal_prc);
|
||||
if (tmp_desc->isCompatible(*current_internal_desc)) {
|
||||
auto mem = std::make_shared<Memory>(get_engine(), current_ext_desc);
|
||||
size_t elements_to_convert = internal_state_mem()->getDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
|
||||
auto external_prc = current_ext_desc->getPrecision();
|
||||
|
||||
cpu_convert(internal_state_mem()->getData(), mem->getData(), internal_prc, external_prc, elements_to_convert);
|
||||
return std::make_shared<Tensor>(mem);
|
||||
}
|
||||
}
|
||||
|
||||
//reorder
|
||||
auto mem = std::make_shared<Memory>(get_engine(), current_ext_desc);
|
||||
mem->load(*(internal_state_mem()));
|
||||
return std::make_shared<Tensor>(mem);
|
||||
}
|
||||
|
||||
VariableStateDoubleBuffer::VariableStateDoubleBuffer(const std::string& name,
|
||||
const MemoryPtr& first_buffer,
|
||||
const MemoryPtr& second_buffer,
|
||||
const MemoryDescPtr& external_desc,
|
||||
const MemoryCPtr& init_val) :
|
||||
VariableStateBase(name, external_desc) {
|
||||
OPENVINO_ASSERT(first_buffer && second_buffer);
|
||||
reset_prime_mem(first_buffer);
|
||||
reset_second_mem(second_buffer);
|
||||
m_internal_desc = prime_mem()->getDescPtr();
|
||||
auto&& shape = m_internal_desc->getShape();
|
||||
//TODO what if by some reason we already have internal static state while the node is dynamic, is it even possible?
|
||||
@ -38,58 +106,6 @@ VariableStateDoubleBuffer::VariableStateDoubleBuffer(std::string name,
|
||||
}
|
||||
}
|
||||
|
||||
void VariableStateDoubleBuffer::set_state(const ov::SoPtr<ov::ITensor>& state) {
|
||||
m_state = state; // simply to extend the lifetime
|
||||
auto state_desc = MemoryDescUtils::generateCpuBlockedMemoryDesc(m_state);
|
||||
|
||||
const auto& shape = state_desc->getShape();
|
||||
|
||||
if (prime_mem()->getShape() != shape) {
|
||||
auto new_desc = m_internal_desc->cloneWithNewDims(shape.getStaticDims());
|
||||
prime_mem()->redefineDesc(new_desc);
|
||||
}
|
||||
|
||||
auto src = m_state->data();
|
||||
|
||||
Memory mem(get_engine(), state_desc, src);
|
||||
prime_mem()->load(mem);
|
||||
}
|
||||
|
||||
const dnnl::engine& VariableStateDoubleBuffer::get_engine() const {
|
||||
static const dnnl::engine eng(dnnl::engine::kind::cpu, 0);
|
||||
return eng;
|
||||
}
|
||||
|
||||
ov::SoPtr<ov::ITensor> VariableStateDoubleBuffer::get_state() const {
|
||||
//TODO , in general case must be synchronized
|
||||
const auto& current_dims = prime_mem()->getStaticDims();
|
||||
auto current_ext_desc = m_external_desc->cloneWithNewDims(current_dims);
|
||||
auto current_internal_desc = prime_mem()->getDescPtr();
|
||||
|
||||
if (current_ext_desc->isCompatible(*current_internal_desc)) {
|
||||
return std::make_shared<Tensor>(prime_mem());
|
||||
}
|
||||
|
||||
//test precision
|
||||
{
|
||||
auto internal_prc = current_internal_desc->getPrecision();
|
||||
auto tmp_desc = current_ext_desc->cloneWithNewPrecision(internal_prc);
|
||||
if (tmp_desc->isCompatible(*current_internal_desc)) {
|
||||
auto mem = std::make_shared<Memory>(get_engine(), current_ext_desc);
|
||||
size_t elements_to_convert = prime_mem()->getDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
|
||||
auto external_prc = current_ext_desc->getPrecision();
|
||||
|
||||
cpu_convert(prime_mem()->getData(), mem->getData(), internal_prc, external_prc, elements_to_convert);
|
||||
return std::make_shared<Tensor>(mem);
|
||||
}
|
||||
}
|
||||
|
||||
//reorder
|
||||
auto mem = std::make_shared<Memory>(get_engine(), current_ext_desc);
|
||||
mem->load(*(prime_mem()));
|
||||
return std::make_shared<Tensor>(mem);
|
||||
}
|
||||
|
||||
void VariableStateDoubleBuffer::reset() {
|
||||
auto new_desc = to_static(m_internal_desc);
|
||||
for (auto&& mem : m_internal_mem) {
|
||||
@ -100,18 +116,6 @@ void VariableStateDoubleBuffer::reset() {
|
||||
}
|
||||
}
|
||||
|
||||
MemoryDescPtr VariableStateDoubleBuffer::to_static(const MemoryDescPtr& desc) {
|
||||
if (!desc->isDefined()) {
|
||||
auto&& current_dims = desc->getShape().getDims();
|
||||
VectorDims new_dims(current_dims.size());
|
||||
std::transform(current_dims.begin(), current_dims.end(), new_dims.begin(), [](Dim x) {
|
||||
return x == Shape::UNDEFINED_DIM ? 0 : x; });
|
||||
|
||||
return desc->cloneWithNewDims(new_dims, true);
|
||||
}
|
||||
return desc;
|
||||
}
|
||||
|
||||
void VariableStateDoubleBuffer::commit() {
|
||||
buffer_num ^= 0x01;
|
||||
}
|
||||
@ -128,5 +132,59 @@ MemoryDescPtr VariableStateDoubleBuffer::internal_desc() const {
|
||||
return m_internal_desc;
|
||||
}
|
||||
|
||||
MemoryPtr VariableStateDoubleBuffer::internal_state_mem() const {
|
||||
return prime_mem();
|
||||
}
|
||||
|
||||
VariableStateSingleBuffer::VariableStateSingleBuffer(const std::string& name,
|
||||
const MemoryPtr& buffer,
|
||||
const MemoryDescPtr& external_desc,
|
||||
const MemoryCPtr& init_val) :
|
||||
VariableStateBase(name, external_desc) {
|
||||
OPENVINO_ASSERT(buffer);
|
||||
m_internal_mem = buffer;
|
||||
m_internal_desc = m_internal_mem->getDescPtr();
|
||||
auto&& shape = m_internal_desc->getShape();
|
||||
//TODO what if by some reason we already have internal static state while the node is dynamic, is it even possible?
|
||||
|
||||
if (shape.isStatic()) {
|
||||
if (init_val) {
|
||||
m_internal_mem->load(*init_val);
|
||||
} else {
|
||||
m_internal_mem->nullify();
|
||||
}
|
||||
} else {
|
||||
//in the case of the original desc has dynamic shape we create an empty tensor
|
||||
auto new_desc = to_static(m_internal_desc);
|
||||
m_internal_mem->redefineDesc(new_desc);
|
||||
}
|
||||
}
|
||||
|
||||
void VariableStateSingleBuffer::reset() {
|
||||
auto new_desc = to_static(m_internal_desc);
|
||||
m_internal_mem->redefineDesc(new_desc);
|
||||
m_internal_mem->nullify();
|
||||
}
|
||||
|
||||
MemoryPtr VariableStateSingleBuffer::input_mem() {
|
||||
return m_internal_mem;
|
||||
}
|
||||
|
||||
MemoryPtr VariableStateSingleBuffer::output_mem() {
|
||||
return m_internal_mem;
|
||||
}
|
||||
|
||||
MemoryDescPtr VariableStateSingleBuffer::internal_desc() const {
|
||||
return m_internal_desc;
|
||||
}
|
||||
|
||||
MemoryPtr VariableStateSingleBuffer::internal_state_mem() const {
|
||||
return m_internal_mem;
|
||||
}
|
||||
|
||||
void VariableStateSingleBuffer::commit() {
|
||||
//nothing to do
|
||||
}
|
||||
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
||||
|
@ -28,20 +28,35 @@ public:
|
||||
virtual MemoryDescPtr internal_desc() const = 0;
|
||||
};
|
||||
|
||||
class VariableStateDoubleBuffer : public IVariableState {
|
||||
class VariableStateBase : public IVariableState {
|
||||
public:
|
||||
using MemBuilder = std::function<MemoryPtr(void)>;
|
||||
VariableStateBase(const std::string& name, const MemoryDescPtr& external_desc);
|
||||
|
||||
public:
|
||||
VariableStateDoubleBuffer(std::string name,
|
||||
const MemBuilder& mem_build,
|
||||
MemoryDescPtr external_desc,
|
||||
MemoryCPtr init_val);
|
||||
//ov::IVariableState
|
||||
void reset() override;
|
||||
void set_state(const ov::SoPtr<ov::ITensor>& state) override;
|
||||
ov::SoPtr<ov::ITensor> get_state() const override;
|
||||
|
||||
protected:
|
||||
virtual MemoryPtr internal_state_mem() const = 0;
|
||||
|
||||
static MemoryDescPtr to_static(const MemoryDescPtr& desc);
|
||||
static const dnnl::engine& get_engine();
|
||||
|
||||
protected:
|
||||
MemoryDescPtr m_external_desc;
|
||||
};
|
||||
|
||||
class VariableStateDoubleBuffer : public VariableStateBase {
|
||||
public:
|
||||
VariableStateDoubleBuffer(const std::string& name,
|
||||
const MemoryPtr& first_buffer,
|
||||
const MemoryPtr& second_buffer,
|
||||
const MemoryDescPtr& external_desc,
|
||||
const MemoryCPtr& init_val);
|
||||
|
||||
//ov::IVariableState
|
||||
void reset() override;
|
||||
|
||||
//ov::intel_cpu::IVariableState
|
||||
void commit() override;
|
||||
|
||||
@ -50,8 +65,6 @@ public:
|
||||
MemoryDescPtr internal_desc() const override;
|
||||
|
||||
private:
|
||||
static MemoryDescPtr to_static(const MemoryDescPtr& desc);
|
||||
|
||||
void reset_prime_mem(const MemoryPtr& mem) {
|
||||
m_internal_mem[buffer_num] = mem;
|
||||
}
|
||||
@ -68,16 +81,38 @@ private:
|
||||
return m_internal_mem[buffer_num ^ 0x1];
|
||||
}
|
||||
|
||||
|
||||
const dnnl::engine& get_engine() const;
|
||||
MemoryPtr internal_state_mem() const override;
|
||||
|
||||
private:
|
||||
MemoryDescPtr m_external_desc;
|
||||
MemoryDescPtr m_internal_desc; //mem desc required by the graph internal tensor
|
||||
std::array<MemoryPtr, 2> m_internal_mem{};
|
||||
size_t buffer_num = 0;
|
||||
};
|
||||
|
||||
class VariableStateSingleBuffer : public VariableStateBase {
|
||||
public:
|
||||
VariableStateSingleBuffer(const std::string& name,
|
||||
const MemoryPtr& buffer,
|
||||
const MemoryDescPtr& external_desc,
|
||||
const MemoryCPtr& init_val);
|
||||
//ov::IVariableState
|
||||
void reset() override;
|
||||
|
||||
//ov::intel_cpu::IVariableState
|
||||
void commit() override;
|
||||
|
||||
MemoryPtr input_mem() override;
|
||||
MemoryPtr output_mem() override;
|
||||
MemoryDescPtr internal_desc() const override;
|
||||
|
||||
private:
|
||||
MemoryPtr internal_state_mem() const override;
|
||||
|
||||
private:
|
||||
MemoryDescPtr m_internal_desc; //mem desc required by the graph internal tensor
|
||||
MemoryPtr m_internal_mem;
|
||||
};
|
||||
|
||||
using MemStatePtr = std::shared_ptr<IVariableState>;
|
||||
using MemStateCPtr = std::shared_ptr<const IVariableState>;
|
||||
} // namespace intel_cpu
|
||||
|
@ -1673,7 +1673,7 @@ void Node::updateLastInputDims() {
|
||||
}
|
||||
|
||||
for (size_t i = 0; i < lastInputDims.size(); i++)
|
||||
lastInputDims[i] = getParentEdgesAtPort(i)[0]->getMemory().getStaticDims();
|
||||
lastInputDims[i] = getParentEdgesAtPort(i)[0]->getMemory().getDesc().getShape().getDims();
|
||||
}
|
||||
|
||||
bool Node::canFuseSimpleOperation(const NodePtr& node) const {
|
||||
|
@ -0,0 +1,39 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "arbitrary_order_desc_creator.h"
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
|
||||
ArbitraryOrderDescCreator::ArbitraryOrderDescCreator(VectorDims order) :
|
||||
m_order(std::move(order)) {
|
||||
OPENVINO_ASSERT(std::adjacent_find(m_order.begin(), m_order.end()) == m_order.end(),
|
||||
"Can't construct ArbitraryOrderDescCreator, order vector contains repetitive elements",
|
||||
vec2str(m_order));
|
||||
}
|
||||
|
||||
CpuBlockedMemoryDesc
|
||||
ArbitraryOrderDescCreator::createDesc(const ov::element::Type& precision, const Shape& srcShape) const {
|
||||
auto&& dims = srcShape.getDims();
|
||||
OPENVINO_ASSERT(dims.size() == m_order.size(),
|
||||
"Couldn't create a tensor descriptor, shape and order size mismatch. Shape: ",
|
||||
vec2str(dims),
|
||||
" order: ",
|
||||
vec2str(m_order));
|
||||
|
||||
VectorDims blkDims(dims.size());
|
||||
for (size_t i = 0; i < dims.size(); ++i) {
|
||||
blkDims[i] = dims[m_order[i]];
|
||||
}
|
||||
|
||||
return CpuBlockedMemoryDesc(precision, srcShape, blkDims, m_order);
|
||||
}
|
||||
|
||||
size_t ArbitraryOrderDescCreator::getMinimalRank() const {
|
||||
return m_order.size();
|
||||
}
|
||||
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
@ -0,0 +1,24 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "blocked_desc_creator.h"
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
|
||||
class ArbitraryOrderDescCreator : public BlockedDescCreator {
|
||||
public:
|
||||
ArbitraryOrderDescCreator(VectorDims order);
|
||||
|
||||
CpuBlockedMemoryDesc createDesc(const ov::element::Type& precision, const Shape& srcShape) const override;
|
||||
size_t getMinimalRank() const override;
|
||||
|
||||
private:
|
||||
VectorDims m_order;
|
||||
};
|
||||
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
@ -395,6 +395,10 @@ Input::Input(const Shape& shape,
|
||||
const GraphContext::CPtr context)
|
||||
: Node(type, name, context) {
|
||||
constant = ConstantType::NoConst;
|
||||
isDynamic = shape.isDynamic();
|
||||
if (isDynamic) {
|
||||
shapeInference = PassThroughShapeInferFactory().makeShapeInfer();
|
||||
}
|
||||
if (getType() == Type::Input) {
|
||||
outputShapes.emplace_back(shape);
|
||||
addOriginalOutputPrecision(prc);
|
||||
|
@ -11,6 +11,8 @@
|
||||
#include "utils/general_utils.h"
|
||||
#include "memory_desc/dnnl_blocked_memory_desc.h"
|
||||
#include "utils/ngraph_utils.hpp"
|
||||
#include "shape_inference/shape_inference_pass_through.hpp"
|
||||
#include "common/arbitrary_order_desc_creator.h"
|
||||
|
||||
using namespace dnnl;
|
||||
using namespace InferenceEngine;
|
||||
@ -23,9 +25,9 @@ std::mutex MemoryNodeVirtualEdge::holderMutex;
|
||||
|
||||
MemoryNode::MemoryNode(const std::shared_ptr<ov::Node>& op) {
|
||||
if (auto assignOp = ov::as_type_ptr<ov::op::util::AssignBase>(op)) {
|
||||
_id = assignOp->get_variable_id();
|
||||
m_id = assignOp->get_variable_id();
|
||||
} else if (auto readValueOp = ov::as_type_ptr<ov::op::util::ReadValueBase>(op)) {
|
||||
_id = readValueOp->get_variable_id();
|
||||
m_id = readValueOp->get_variable_id();
|
||||
} else {
|
||||
OPENVINO_THROW("Unexpected ov::Node type: ", op->get_type_info().name, " in MemoryNode");
|
||||
}
|
||||
@ -61,7 +63,7 @@ MemoryOutput::~MemoryOutput() {
|
||||
MemoryNodeVirtualEdge::remove(this, holder);
|
||||
}
|
||||
|
||||
MemoryInput& MemoryOutput::getInputNode() {
|
||||
MemoryInputBase& MemoryOutput::getInputNode() {
|
||||
OPENVINO_ASSERT(inputNode, "MemoryOutput ", getName(), " doesn't have sibling input");
|
||||
return *inputNode;
|
||||
}
|
||||
@ -196,19 +198,19 @@ void MemoryOutput::executeDynamicImpl(dnnl::stream strm) {
|
||||
execute(strm);
|
||||
}
|
||||
|
||||
void MemoryOutput::registerInputNode(MemoryInput* node) {
|
||||
void MemoryOutput::registerInputNode(MemoryInputBase* node) {
|
||||
if (inputNode == node) { return; }
|
||||
if (inputNode) { inputNode->deregisterSibling(this); }
|
||||
inputNode = node;
|
||||
inputNode->registerOutputNode(this);
|
||||
}
|
||||
|
||||
void MemoryOutput::deregisterSibling(MemoryNode* node) {
|
||||
void MemoryOutput::deregisterSibling(MemoryInputBase* node) {
|
||||
if (node == inputNode) { inputNode = nullptr; }
|
||||
}
|
||||
|
||||
|
||||
bool MemoryInput::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
|
||||
bool MemoryInputBase::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
|
||||
try {
|
||||
if (!one_of(op->get_type_info(),
|
||||
ov::op::v3::ReadValue::get_type_info_static(),
|
||||
@ -222,8 +224,8 @@ bool MemoryInput::isSupportedOperation(const std::shared_ptr<const ov::Node>& op
|
||||
return true;
|
||||
}
|
||||
|
||||
MemoryInput::MemoryInput(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr ctx)
|
||||
: Input(op, ctx), MemoryNode(op) {
|
||||
MemoryInputBase::MemoryInputBase(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr ctx)
|
||||
: Input(op, ctx), MemoryStateNode(op) {
|
||||
std::string errorMessage;
|
||||
if (!isSupportedOperation(op, errorMessage)) {
|
||||
OPENVINO_THROW_NOT_IMPLEMENTED(errorMessage);
|
||||
@ -233,7 +235,32 @@ MemoryInput::MemoryInput(const std::shared_ptr<ov::Node>& op, const GraphContext
|
||||
}
|
||||
}
|
||||
|
||||
void MemoryInput::createPrimitive() {
|
||||
MemoryInputBase::MemoryInputBase(const std::string id,
|
||||
const std::string& name,
|
||||
const std::string& type,
|
||||
const Shape& output_shape,
|
||||
const ov::element::Type& output_prc,
|
||||
const GraphContext::CPtr context,
|
||||
const ov::optional<Shape>& input_shape,
|
||||
const ov::optional<ov::element::Type>& input_prc) :
|
||||
Input(output_shape, output_prc, name, type, context), MemoryStateNode(id) {
|
||||
outputShapes.emplace_back(output_shape);
|
||||
addOriginalOutputPrecision(output_prc);
|
||||
if (input_shape) {
|
||||
inputShapes.push_back(*input_shape);
|
||||
isDynamic = isDynamic || input_shape->isDynamic();
|
||||
if (isDynamic && !shapeInference) {
|
||||
shapeInference = PassThroughShapeInferFactory().makeShapeInfer();
|
||||
}
|
||||
}
|
||||
if (input_prc) {
|
||||
addOriginalInputPrecision(*input_prc);
|
||||
}
|
||||
// We don't need to use a virtual edge since this constructor is used in transformations and
|
||||
// this is their responsibility to link the input/output nodes properly
|
||||
}
|
||||
|
||||
void MemoryInputBase::createPrimitive() {
|
||||
Input::createPrimitive();
|
||||
if (!inputShapes.empty()) {
|
||||
auto parentEdge = getParentEdgeAt(0);
|
||||
@ -244,63 +271,7 @@ void MemoryInput::createPrimitive() {
|
||||
}
|
||||
}
|
||||
|
||||
void MemoryInput::initSupportedPrimitiveDescriptors() {
|
||||
if (!supportedPrimitiveDescriptors.empty())
|
||||
return;
|
||||
|
||||
auto&& shape = getOutputShapeAtPort(0);
|
||||
auto precision = getOriginalOutputPrecisionAtPort(0);
|
||||
auto&& descCreators = ov::intel_cpu::BlockedDescCreator::getCommonCreators();
|
||||
|
||||
NodeConfig config;
|
||||
|
||||
if (!getParentEdges().empty()) {
|
||||
PortConfig inPortConfig;
|
||||
|
||||
inPortConfig.inPlace(-1);
|
||||
inPortConfig.constant(false);
|
||||
inPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape));
|
||||
|
||||
config.inConfs.push_back(std::move(inPortConfig));
|
||||
}
|
||||
|
||||
PortConfig outPortConfig;
|
||||
|
||||
outPortConfig.inPlace(0);
|
||||
outPortConfig.constant(false);
|
||||
outPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape));
|
||||
|
||||
config.outConfs.push_back(std::move(outPortConfig));
|
||||
|
||||
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
|
||||
}
|
||||
|
||||
void MemoryInput::initOptimalPrimitiveDescriptor() {
|
||||
// Mimic the child node memory desc to avoid extra reorder
|
||||
auto childEdge = getChildEdgeAt(0);
|
||||
auto child = childEdge->getChild();
|
||||
auto childPd = child->getSelectedPrimitiveDescriptor();
|
||||
OPENVINO_ASSERT(childPd,
|
||||
child->getTypeStr(), " ",
|
||||
child->getName(),
|
||||
"failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set");
|
||||
|
||||
const auto& childConfig = childPd->getConfig();
|
||||
auto mem_desc = childConfig.inConfs[childEdge->getOutputNum()].getMemDesc();
|
||||
|
||||
auto selectedPd = getSelectedPrimitiveDescriptor();
|
||||
OPENVINO_ASSERT(selectedPd,
|
||||
"MemoryInput ",
|
||||
getName(),
|
||||
" failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set");
|
||||
|
||||
auto config = selectedPd->getConfig();
|
||||
config.outConfs.front().setMemDesc(mem_desc);
|
||||
//bypass any checks, we enforce the child descriptor
|
||||
selectedPd->setConfig(config);
|
||||
}
|
||||
|
||||
void MemoryInput::resolveInPlaceEdges(Edge::LOOK look) {
|
||||
void MemoryInputBase::resolveInPlaceEdges(Edge::LOOK look) {
|
||||
if (!(look & Edge::LOOK_UP)) {
|
||||
Node::resolveInPlaceEdges(look);
|
||||
return;
|
||||
@ -324,17 +295,17 @@ void MemoryInput::resolveInPlaceEdges(Edge::LOOK look) {
|
||||
}
|
||||
}
|
||||
|
||||
MemoryInput::~MemoryInput() {
|
||||
MemoryInputBase::~MemoryInputBase() {
|
||||
if (outputNode) { outputNode->deregisterSibling(this); }
|
||||
MemoryNodeVirtualEdge::remove(this, holder);
|
||||
}
|
||||
|
||||
MemoryOutput& MemoryInput::getOutputNode() {
|
||||
MemoryOutput& MemoryInputBase::getOutputNode() {
|
||||
OPENVINO_ASSERT(outputNode, "MemoryOutput ", getName(), " doesn't have sibling input");
|
||||
return *outputNode;
|
||||
}
|
||||
|
||||
void MemoryInput::assignState(MemStatePtr newState) {
|
||||
void MemoryInputBase::assignState(MemStatePtr newState) {
|
||||
assignedMem = newState->input_mem();
|
||||
|
||||
OPENVINO_ASSERT(assignedMem,
|
||||
@ -387,40 +358,18 @@ void MemoryInput::assignState(MemStatePtr newState) {
|
||||
getOutputNode().assignExtMemory(newState->output_mem(), newState->internal_desc());
|
||||
}
|
||||
|
||||
MemStatePtr MemoryInput::makeState() const {
|
||||
// assume ov::Tensor is always dense
|
||||
auto original_desc =
|
||||
std::make_shared<CpuBlockedMemoryDesc>(getOriginalOutputPrecisionAtPort(0), outputShapes.at(0));
|
||||
|
||||
auto mem_desc = getBaseMemDescAtOutputPort(0);
|
||||
const auto& eng = getEngine();
|
||||
|
||||
auto state_name = getId();
|
||||
|
||||
// Remove suffix with pair ID. Internal information.
|
||||
auto suffix_idx = state_name.find("/id=");
|
||||
if (suffix_idx != std::string::npos) {
|
||||
state_name = state_name.substr(0, suffix_idx);
|
||||
}
|
||||
|
||||
return std::make_shared<VariableStateDoubleBuffer>(state_name,
|
||||
[mem_desc, eng](){ return std::make_shared<Memory>(eng, mem_desc); },
|
||||
original_desc,
|
||||
getMemoryPtr());
|
||||
}
|
||||
|
||||
void MemoryInput::registerOutputNode(MemoryOutput* node) {
|
||||
void MemoryInputBase::registerOutputNode(MemoryOutput* node) {
|
||||
if (outputNode == node) { return; }
|
||||
if (outputNode) { outputNode->deregisterSibling(this); }
|
||||
outputNode = node;
|
||||
outputNode->registerInputNode(this);
|
||||
}
|
||||
|
||||
void MemoryInput::deregisterSibling(MemoryNode* node) {
|
||||
void MemoryInputBase::deregisterSibling(MemoryOutput* node) {
|
||||
if (node == outputNode) { outputNode = nullptr; }
|
||||
}
|
||||
|
||||
MemoryNodeVirtualEdge::Holder* MemoryNodeVirtualEdge::registerInput(MemoryInput * node) {
|
||||
MemoryNodeVirtualEdge::Holder* MemoryNodeVirtualEdge::registerInput(MemoryInputBase * node) {
|
||||
std::lock_guard<std::mutex> lock{MemoryNodeVirtualEdge::holderMutex};
|
||||
// in case of output already registered
|
||||
auto& holder = MemoryNodeVirtualEdge::getExisted();
|
||||
@ -441,7 +390,7 @@ MemoryNodeVirtualEdge::Holder* MemoryNodeVirtualEdge::registerOutput(MemoryOutpu
|
||||
auto& holder = MemoryNodeVirtualEdge::getExisted();
|
||||
auto sibling = MemoryNodeVirtualEdge::getByName(holder, node->getId());
|
||||
if (sibling != nullptr) {
|
||||
auto inputNode = dynamic_cast<MemoryInput*>(sibling);
|
||||
auto inputNode = dynamic_cast<MemoryInputBase*>(sibling);
|
||||
OPENVINO_ASSERT(inputNode != nullptr);
|
||||
node->registerInputNode(inputNode);
|
||||
} else {
|
||||
@ -458,6 +407,210 @@ void MemoryNodeVirtualEdge::remove(MemoryNode * node, Holder* holder) {
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
void MemoryInput::initSupportedPrimitiveDescriptors() {
|
||||
if (!supportedPrimitiveDescriptors.empty())
|
||||
return;
|
||||
|
||||
auto&& shape = getOutputShapeAtPort(0);
|
||||
auto precision = getOriginalOutputPrecisionAtPort(0);
|
||||
auto&& descCreators = ov::intel_cpu::BlockedDescCreator::getCommonCreators();
|
||||
|
||||
NodeConfig config;
|
||||
|
||||
if (!getParentEdges().empty()) {
|
||||
PortConfig inPortConfig;
|
||||
|
||||
inPortConfig.inPlace(-1);
|
||||
inPortConfig.constant(false);
|
||||
inPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape));
|
||||
|
||||
config.inConfs.push_back(std::move(inPortConfig));
|
||||
}
|
||||
|
||||
PortConfig outPortConfig;
|
||||
|
||||
outPortConfig.inPlace(0);
|
||||
outPortConfig.constant(false);
|
||||
outPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape));
|
||||
|
||||
config.outConfs.push_back(std::move(outPortConfig));
|
||||
|
||||
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
|
||||
}
|
||||
|
||||
void MemoryInput::initOptimalPrimitiveDescriptor() {
|
||||
// Mimic the child node memory desc to avoid extra reorder
|
||||
static const Type preferredTypes[] = {
|
||||
Type::ScaledDotProductAttention,
|
||||
Type::MatMul,
|
||||
Type::FullyConnected,
|
||||
Type::Convolution,
|
||||
Type::RNNCell,
|
||||
Type::RNNSeq,
|
||||
Type::Subgraph
|
||||
};
|
||||
|
||||
static const Type skipTypes[] = {
|
||||
Type::ShapeOf
|
||||
};
|
||||
|
||||
auto&& childEdges = getChildEdgesAtPort(0);
|
||||
EdgePtr childEdge = childEdges.front();
|
||||
|
||||
if (childEdges.size() > 1) {
|
||||
// try to prioritize memory desc
|
||||
for (auto&& item : childEdges) {
|
||||
auto itemType = item->getChild()->getType();
|
||||
if (std::any_of(std::begin(skipTypes), std::end(skipTypes), [=](Type type){ return type == itemType; })) {
|
||||
continue;
|
||||
}
|
||||
if (std::any_of(std::begin(preferredTypes),
|
||||
std::end(preferredTypes), [=](Type type){ return type == itemType; })) {
|
||||
childEdge = item;
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
auto child = childEdge->getChild();
|
||||
auto childPd = child->getSelectedPrimitiveDescriptor();
|
||||
OPENVINO_ASSERT(childPd,
|
||||
child->getTypeStr(), " ",
|
||||
child->getName(),
|
||||
"failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set");
|
||||
|
||||
const auto& childConfig = childPd->getConfig();
|
||||
auto mem_desc = childConfig.inConfs[childEdge->getOutputNum()].getMemDesc();
|
||||
|
||||
auto selectedPd = getSelectedPrimitiveDescriptor();
|
||||
OPENVINO_ASSERT(selectedPd,
|
||||
"MemoryInput ",
|
||||
getName(),
|
||||
" failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set");
|
||||
|
||||
auto config = selectedPd->getConfig();
|
||||
config.outConfs.front().setMemDesc(mem_desc);
|
||||
//bypass any checks, we enforce the child descriptor
|
||||
selectedPd->setConfig(config);
|
||||
}
|
||||
|
||||
MemStatePtr MemoryInput::makeState() const {
|
||||
// assume ov::Tensor is always dense
|
||||
auto original_desc =
|
||||
std::make_shared<CpuBlockedMemoryDesc>(getOriginalOutputPrecisionAtPort(0), outputShapes.at(0));
|
||||
|
||||
auto mem_desc = getBaseMemDescAtOutputPort(0);
|
||||
const auto& eng = getEngine();
|
||||
|
||||
auto state_name = getId();
|
||||
|
||||
// Remove suffix with pair ID. Internal information.
|
||||
auto suffix_idx = state_name.find("/id=");
|
||||
if (suffix_idx != std::string::npos) {
|
||||
state_name = state_name.substr(0, suffix_idx);
|
||||
}
|
||||
|
||||
return std::make_shared<VariableStateDoubleBuffer>(state_name,
|
||||
std::make_shared<Memory>(eng, mem_desc),
|
||||
std::make_shared<Memory>(eng, mem_desc),
|
||||
original_desc,
|
||||
getMemoryPtr());
|
||||
}
|
||||
|
||||
bool MemoryInput::isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept {
|
||||
return MemoryInputBase::isSupportedOperation(op, errorMessage);
|
||||
}
|
||||
|
||||
void MemoryInputSDPA::initSupportedPrimitiveDescriptors() {
|
||||
if (!supportedPrimitiveDescriptors.empty())
|
||||
return;
|
||||
|
||||
auto&& shape = getOutputShapeAtPort(0);
|
||||
auto precision = getOriginalOutputPrecisionAtPort(0);
|
||||
auto&& descCreators = ov::intel_cpu::BlockedDescCreator::getCommonCreators();
|
||||
NodeConfig config;
|
||||
if (!getParentEdges().empty()) {
|
||||
PortConfig inPortConfig;
|
||||
inPortConfig.inPlace(-1);
|
||||
inPortConfig.constant(false);
|
||||
inPortConfig.setMemDesc(descCreators.at(LayoutType::ncsp)->createSharedDesc(precision, shape));
|
||||
config.inConfs.push_back(std::move(inPortConfig));
|
||||
}
|
||||
|
||||
auto&& childEdges = getChildEdgesAtPort(0);
|
||||
auto itr = std::find_if(childEdges.begin(), childEdges.end(),
|
||||
[](const EdgePtr& edge){ return Type::ScaledDotProductAttention == edge->getChild()->getType(); });
|
||||
|
||||
OPENVINO_ASSERT(itr != childEdges.end(), "MemoryInputSDPA isn't attached to an SDPA node");
|
||||
auto SDPA = (*itr)->getChild();
|
||||
auto childPort = (*itr)->getOutputNum();
|
||||
|
||||
// Since this is a very specialized implementation, lets mimic SDPA precision and set cabd layout
|
||||
precision = SDPA->getOriginalInputPrecisionAtPort(childPort);
|
||||
ArbitraryOrderDescCreator cabdDescCreator({2, 0, 1, 3});
|
||||
|
||||
PortConfig outPortConfig;
|
||||
outPortConfig.inPlace(0);
|
||||
outPortConfig.constant(false);
|
||||
outPortConfig.setMemDesc(cabdDescCreator.createSharedDesc(precision, shape));
|
||||
config.outConfs.push_back(std::move(outPortConfig));
|
||||
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::unknown);
|
||||
}
|
||||
|
||||
void MemoryInputSDPA::initOptimalPrimitiveDescriptor() {
|
||||
auto&& childEdges = getChildEdgesAtPort(0);
|
||||
auto itr = std::find_if(childEdges.begin(), childEdges.end(),
|
||||
[](const EdgePtr& edge){ return Type::ScaledDotProductAttention == edge->getChild()->getType(); });
|
||||
|
||||
OPENVINO_ASSERT(itr != childEdges.end(), "MemoryInputSDPA isn't attached to an SDPA node");
|
||||
auto childEdge = *itr;
|
||||
auto child = childEdge->getChild();
|
||||
auto childPd = child->getSelectedPrimitiveDescriptor();
|
||||
OPENVINO_ASSERT(childPd,
|
||||
child->getTypeStr(), " ",
|
||||
child->getName(),
|
||||
"failed initOptimalPrimitiveDescriptor() call, preferable primitive descriptor is not set");
|
||||
|
||||
const auto& childConfig = childPd->getConfig();
|
||||
auto childPrecision = childConfig.inConfs[childEdge->getOutputNum()].getMemDesc()->getPrecision();
|
||||
|
||||
auto selectedPd = getSelectedPrimitiveDescriptor();
|
||||
OPENVINO_ASSERT(selectedPd,
|
||||
"MemoryInputSDPA ",
|
||||
getName(),
|
||||
" failed initOptimalPrimitiveDescriptor() call, preferable primitive descriptor is not set");
|
||||
|
||||
auto config = selectedPd->getConfig();
|
||||
auto memDesc = config.outConfs.front().getMemDesc();
|
||||
auto newMemDesc = memDesc->cloneWithNewPrecision(childPrecision);
|
||||
config.outConfs.front().setMemDesc(newMemDesc);
|
||||
//bypass any checks, we enforce the child descriptor precision
|
||||
selectedPd->setConfig(config);
|
||||
}
|
||||
|
||||
MemStatePtr MemoryInputSDPA::makeState() const {
|
||||
// assume ov::Tensor is always dense
|
||||
auto original_desc =
|
||||
std::make_shared<CpuBlockedMemoryDesc>(getOriginalOutputPrecisionAtPort(0), outputShapes.at(0));
|
||||
|
||||
auto mem_desc = getBaseMemDescAtOutputPort(0);
|
||||
const auto& eng = getEngine();
|
||||
|
||||
auto state_name = getId();
|
||||
|
||||
// Remove suffix with pair ID. Internal information.
|
||||
auto suffix_idx = state_name.find("/id=");
|
||||
if (suffix_idx != std::string::npos) {
|
||||
state_name = state_name.substr(0, suffix_idx);
|
||||
}
|
||||
|
||||
return std::make_shared<VariableStateSingleBuffer>(state_name,
|
||||
std::make_shared<Memory>(eng, mem_desc, std::make_shared<DnnlMemoryMngr>(make_unique<MemoryMngrRealloc>())),
|
||||
original_desc,
|
||||
getMemoryPtr());
|
||||
}
|
||||
|
||||
} // namespace node
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "ie_algorithm.hpp"
|
||||
#include "input.h"
|
||||
#include <node.h>
|
||||
#include <ov_optional.hpp>
|
||||
#include <memory_state.h>
|
||||
#include <proxy_mem_mgr.h>
|
||||
#include <string>
|
||||
@ -20,20 +21,24 @@ namespace intel_cpu {
|
||||
namespace node {
|
||||
|
||||
class MemoryOutput;
|
||||
class MemoryInput;
|
||||
class MemoryInputBase;
|
||||
|
||||
class MemoryNode { //TODO , segregate interfaces
|
||||
std::string _id;
|
||||
class MemoryNode {
|
||||
public:
|
||||
explicit MemoryNode(std::string id) : _id(id) {}
|
||||
explicit MemoryNode(std::string id) : m_id(id) {}
|
||||
explicit MemoryNode(const std::shared_ptr<ov::Node>& op);
|
||||
virtual ~MemoryNode() = default;
|
||||
std::string getId() const {
|
||||
return _id;
|
||||
return m_id;
|
||||
}
|
||||
virtual void registerInputNode(MemoryInput*) = 0;
|
||||
virtual void registerOutputNode(MemoryOutput*) = 0;
|
||||
virtual void deregisterSibling(MemoryNode*) = 0;
|
||||
|
||||
private:
|
||||
std::string m_id;
|
||||
};
|
||||
|
||||
class MemoryStateNode : public MemoryNode {
|
||||
public:
|
||||
using MemoryNode::MemoryNode;
|
||||
virtual void assignState(MemStatePtr newState) = 0;
|
||||
virtual MemStatePtr makeState() const = 0;
|
||||
};
|
||||
@ -60,7 +65,7 @@ public:
|
||||
}
|
||||
|
||||
static Holder* registerOutput(MemoryOutput * node);
|
||||
static Holder* registerInput(MemoryInput * node);
|
||||
static Holder* registerInput(MemoryInputBase * node);
|
||||
static void remove(MemoryNode * node, Holder* holder);
|
||||
static std::mutex holderMutex;
|
||||
};
|
||||
@ -81,12 +86,8 @@ public:
|
||||
}
|
||||
void resolveInPlaceEdges(Edge::LOOK look) override;
|
||||
|
||||
void registerInputNode(MemoryInput* node) override;
|
||||
void registerOutputNode(MemoryOutput* node) override {
|
||||
OPENVINO_THROW("MemoryOutput node has no MemoryOutput type sibling!");
|
||||
}
|
||||
|
||||
void deregisterSibling(MemoryNode* node) override;
|
||||
void registerInputNode(MemoryInputBase* node);
|
||||
void deregisterSibling(MemoryInputBase* node);
|
||||
|
||||
bool needShapeInfer() const override { return false; }
|
||||
bool needPrepareParams() const override { return false; }
|
||||
@ -94,38 +95,38 @@ public:
|
||||
void assignExtMemory(const MemoryPtr& mem, const MemoryDescPtr& memDesc);
|
||||
|
||||
private:
|
||||
MemoryInput& getInputNode();
|
||||
void assignState(MemStatePtr newState) override {
|
||||
OPENVINO_THROW("Unexpected MemoryOutput::assignState call"); //TODO , segregate interfaces
|
||||
}
|
||||
MemStatePtr makeState() const override {
|
||||
OPENVINO_THROW("Unexpected MemoryOutput::makeState call"); //TODO , segregate interfaces
|
||||
}
|
||||
MemoryInputBase& getInputNode();
|
||||
|
||||
private:
|
||||
/**
|
||||
* @brief keeps reference to input sibling node
|
||||
*/
|
||||
MemoryInput* inputNode = nullptr;
|
||||
MemoryInputBase* inputNode = nullptr;
|
||||
MemoryPtr assignedMem = nullptr;
|
||||
MemoryDescPtr extMemDesc = nullptr; // used for resize
|
||||
MemoryNodeVirtualEdge::Holder* holder = nullptr;
|
||||
ProxyMemoryMngrPtr memMngr = nullptr;
|
||||
};
|
||||
|
||||
class MemoryInput : public Input, public MemoryNode {
|
||||
class MemoryInputBase : public Input, public MemoryStateNode {
|
||||
public:
|
||||
MemoryInput(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context);
|
||||
~MemoryInput() override;
|
||||
MemoryInputBase(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context);
|
||||
MemoryInputBase(const std::string id,
|
||||
const std::string& name,
|
||||
const std::string& type,
|
||||
const Shape& output_shape,
|
||||
const ov::element::Type& output_prc,
|
||||
const GraphContext::CPtr context,
|
||||
const ov::optional<Shape>& input_shape,
|
||||
const ov::optional<ov::element::Type>& input_prc);
|
||||
|
||||
~MemoryInputBase() override;
|
||||
|
||||
static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
|
||||
bool created() const override {
|
||||
return getType() == Type::MemoryInput;
|
||||
}
|
||||
|
||||
void initSupportedPrimitiveDescriptors() override;
|
||||
void initOptimalPrimitiveDescriptor() override;
|
||||
|
||||
void execute(dnnl::stream strm) override {/*pass*/}
|
||||
void executeDynamicImpl(dnnl::stream strm) override {/*pass*/}
|
||||
|
||||
@ -133,17 +134,11 @@ public:
|
||||
|
||||
void resolveInPlaceEdges(Edge::LOOK look) override;
|
||||
|
||||
void registerInputNode(MemoryInput* node) override {
|
||||
OPENVINO_THROW("MemoryInput node has no MemoryInput type sibling!");
|
||||
}
|
||||
|
||||
void registerOutputNode(MemoryOutput* node) override;
|
||||
void deregisterSibling(MemoryNode* node) override;
|
||||
void registerOutputNode(MemoryOutput* node);
|
||||
void deregisterSibling(MemoryOutput* node);
|
||||
|
||||
// May be extracted to some interface when necessary
|
||||
void assignState(MemStatePtr newState) override;
|
||||
MemStatePtr makeState() const override;
|
||||
|
||||
private:
|
||||
MemoryOutput& getOutputNode();
|
||||
|
||||
private:
|
||||
@ -156,6 +151,27 @@ private:
|
||||
ProxyMemoryMngrPtr memMngr = nullptr;
|
||||
};
|
||||
|
||||
class MemoryInput : public MemoryInputBase {
|
||||
public:
|
||||
using MemoryInputBase::MemoryInputBase;
|
||||
static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
|
||||
|
||||
void initSupportedPrimitiveDescriptors() override;
|
||||
void initOptimalPrimitiveDescriptor() override;
|
||||
|
||||
MemStatePtr makeState() const override;
|
||||
};
|
||||
|
||||
class MemoryInputSDPA : public MemoryInputBase {
|
||||
public:
|
||||
using MemoryInputBase::MemoryInputBase;
|
||||
static bool isSupportedOperation(const std::shared_ptr<const ov::Node>& op, std::string& errorMessage) noexcept;
|
||||
|
||||
void initSupportedPrimitiveDescriptors() override;
|
||||
void initOptimalPrimitiveDescriptor() override;
|
||||
|
||||
MemStatePtr makeState() const override;
|
||||
};
|
||||
} // namespace node
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
||||
|
@ -15,11 +15,12 @@
|
||||
#include <shape_inference/shape_inference_internal_dyn.hpp>
|
||||
#include <vector>
|
||||
|
||||
#include "common/cpu_memcpy.h"
|
||||
#include "openvino/core/parallel.hpp"
|
||||
#include "memory_desc/cpu_memory_desc_utils.h"
|
||||
#include "memory_desc/dnnl_blocked_memory_desc.h"
|
||||
#include "utils/plain_tensor.hpp"
|
||||
#include <openvino/op/scaled_dot_product_attention.hpp>
|
||||
#include "common/arbitrary_order_desc_creator.h"
|
||||
|
||||
#ifdef OV_CPU_WITH_MLAS
|
||||
# include "mlas/sgemm.hpp"
|
||||
@ -576,48 +577,93 @@ struct MHASingleToken {
|
||||
|
||||
template <ScaledDotProductAttention::KernelTypes KType, typename T>
|
||||
struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAttention::Executor {
|
||||
PlainTensor<T> q_input; // f32[B, L1, H*S] / [B, H, L1, S]
|
||||
PlainTensor<T> k_input; // f32[B, L1, H*S]
|
||||
PlainTensor<T> v_input; // f32[B, L1, H*S]
|
||||
PlainTensor<T> k_cache; // f32[B, H, max_kvLen, S]
|
||||
PlainTensor<T> v_cache; // f32[B, H, max_kvLen, S]
|
||||
PlainTensor<T> q_input; // f32[B, H, L1, S]
|
||||
PlainTensor<T> k_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S]
|
||||
PlainTensor<T> v_input; // f32[B, H|1, L1, S] / [B, H|1, L0+L1, S]
|
||||
PlainTensor<int32_t> beam_table; // i32[B, max_kvLen]
|
||||
PlainTensor<float> attn_mask; // f32[B, qLen + kvLen]
|
||||
float scale_input = 0.0f; // f32[B, qLen + kvLen]
|
||||
PlainTensor<float> cos_tab; // f32[max_kv_len, rotary_dims//2]
|
||||
PlainTensor<float> sin_tab; // f32[max_kv_len, rotary_dims//2]
|
||||
|
||||
PlainTensor<T> output_emb; // f32[B, L1, H*S]
|
||||
PlainTensor<float> attn_buf; // f32[[B|1],[H|1], L1|1, L0+L1]
|
||||
float scale_input = 0.0f;
|
||||
|
||||
MHAKernel<KType, T> kernel;
|
||||
MHASingleToken<T> kernel_single_token;
|
||||
|
||||
PlainTensor<T> m_query_emb; // query with RoPE position embedding
|
||||
size_t B, H, L1, L0, S;
|
||||
|
||||
ScaledDotProductAttention::Config config;
|
||||
AttentionExecutor(const ScaledDotProductAttention::Config& _config) : config(_config) {}
|
||||
Config config;
|
||||
AttentionExecutor(const Config& _config) : attn_buf(true), config(_config) {}
|
||||
|
||||
void prepare_attn_mask(MemoryPtr attn_input) {
|
||||
attn_mask.resize(attn_input->getStaticDims());
|
||||
attn_buf.resize(attn_input->getStaticDims());
|
||||
auto p = reinterpret_cast<uint8_t*>(attn_input->getData());
|
||||
for (size_t i = 0; i < attn_input->getSize(); i++)
|
||||
attn_mask.data()[i] = p[i] ? 0.0f : -FLT_MAX;
|
||||
attn_buf.data()[i] = p[i] ? 0.0f : -FLT_MAX;
|
||||
}
|
||||
|
||||
void concat_pastkv(const std::vector<MemoryPtr>& inputs,
|
||||
const std::vector<MemoryPtr>& outputs,
|
||||
const PlainTensor<T>& k_input,
|
||||
const PlainTensor<T>& v_input,
|
||||
PlainTensor<T>& past_k_output,
|
||||
PlainTensor<T>& past_v_output) {
|
||||
if (config.config.fuse_concat) {
|
||||
k_input.assert_dims({B, 0, L1, S}, true);
|
||||
v_input.assert_dims({B, 0, L1, S}, true);
|
||||
auto past_k_idx = inputs.size() - 2;
|
||||
auto past_k_mem = inputs[past_k_idx + 0];
|
||||
L0 = past_k_mem->getStaticDims()[2];
|
||||
// k,v may support multiquery
|
||||
auto Hk = past_k_mem->getStaticDims()[1];
|
||||
// [B, H, L0, S]
|
||||
past_k_output.reset(outputs[1]);
|
||||
past_v_output.reset(outputs[2]);
|
||||
parallel_for3d(B, Hk, L1, [&](size_t b, size_t h, size_t m) {
|
||||
std::memcpy(&past_k_output.at({b, h, m + L0, 0}),
|
||||
&k_input.at({b, h, m, 0}),
|
||||
S * sizeof(T));
|
||||
std::memcpy(&past_v_output.at({b, h, m + L0, 0}),
|
||||
&v_input.at({b, h, m, 0}),
|
||||
S * sizeof(T));
|
||||
});
|
||||
if (!config.is_concat_inplaced) {
|
||||
PlainTensor<T> past_k_input, past_v_input;
|
||||
past_k_input.reset(past_k_mem);
|
||||
past_v_input.reset(inputs[past_k_idx + 1]);
|
||||
parallel_for3d(B, Hk, L0, [&](size_t b, size_t h, size_t m) {
|
||||
std::memcpy(&past_k_output.at({b, h, m, 0}),
|
||||
&past_k_input.at({b, h, m, 0}),
|
||||
S * sizeof(T));
|
||||
std::memcpy(&past_v_output.at({b, h, m, 0}),
|
||||
&past_v_input.at({b, h, m, 0}),
|
||||
S * sizeof(T));
|
||||
});
|
||||
}
|
||||
} else {
|
||||
// k,v inputs are already concatenated
|
||||
L0 = k_input.size(2) - L1;
|
||||
k_input.assert_dims({B, 0, L0 + L1, S}, true);
|
||||
v_input.assert_dims({B, 0, L0 + L1, S}, true);
|
||||
past_k_output = k_input;
|
||||
past_v_output = v_input;
|
||||
}
|
||||
}
|
||||
|
||||
void execute(dnnl::stream strm, const std::vector<MemoryPtr>& inputs, const std::vector<MemoryPtr>& outputs) override {
|
||||
bool has_out_transpose = config.output_BLHxS;
|
||||
bool fuse_causal_attn = config.fuse_causal_attn;
|
||||
bool is_causal = config.is_causal;
|
||||
auto input_num = inputs.size();
|
||||
bool has_out_transpose = config.config.output_BLHxS;
|
||||
bool fuse_causal_attn = config.config.fuse_causal_attn;
|
||||
bool is_causal = config.config.is_causal;
|
||||
const bool fuse_concat = config.config.fuse_concat;
|
||||
auto input_num = inputs.size() - (fuse_concat ? 2 : 0);
|
||||
|
||||
q_input.reset(inputs[0]);
|
||||
k_input.reset(inputs[1]);
|
||||
v_input.reset(inputs[2]);
|
||||
PlainTensor<float> attn_mask;
|
||||
if (input_num > 3) {
|
||||
// attn_mask
|
||||
if (inputs[3]->getDesc().getPrecision() == ov::element::u8) {
|
||||
// bool->f32
|
||||
prepare_attn_mask(inputs[3]);
|
||||
attn_mask = attn_buf;
|
||||
} else {
|
||||
attn_mask.reset(inputs[3]);
|
||||
}
|
||||
@ -627,30 +673,22 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
|
||||
}
|
||||
}
|
||||
|
||||
size_t B, H, L1, L0, S;
|
||||
|
||||
// q, k, v: [B, H, L1, S]
|
||||
// q: [B, H, L1, S]
|
||||
B = q_input.size(0);
|
||||
H = q_input.size(1);
|
||||
L1 = q_input.size(2);
|
||||
L0 = k_input.size(2) - L1;
|
||||
S = q_input.size(-1);
|
||||
|
||||
ov::intel_cpu::PlainTensor<T> output_emb(outputs[0]);
|
||||
PlainTensor<T> present_key, present_value;
|
||||
concat_pastkv(inputs, outputs, k_input, v_input, present_key, present_value);
|
||||
|
||||
q_input.assert_dims({B, H, L1, S});
|
||||
k_input.assert_dims({B, 0, L0 + L1, S}, true);
|
||||
v_input.assert_dims({B, 0, L0 + L1, S}, true);
|
||||
m_query_emb = q_input;
|
||||
present_key = k_input;
|
||||
present_value = v_input;
|
||||
ov::intel_cpu::PlainTensor<T> output_emb(outputs[0]);
|
||||
|
||||
bool auto_causal;
|
||||
bool use_attn_mask;
|
||||
if (fuse_causal_attn) {
|
||||
assert(attn_mask);
|
||||
attn_mask.assert_dims({B, 1, 1, L0 + L1});
|
||||
attn_mask.assert_dims({B, 1, L1, L0 + L1});
|
||||
auto_causal = true;
|
||||
use_attn_mask = true;
|
||||
} else {
|
||||
@ -677,7 +715,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
|
||||
|
||||
if (L1 > 1) {
|
||||
// multi-token version
|
||||
kernel(strm, m_query_emb, present_key, present_value, {}, use_attn_mask ? attn_mask : PlainTensor<float>(),
|
||||
kernel(strm, q_input, k_input, v_input, {}, use_attn_mask ? attn_mask : PlainTensor<float>(),
|
||||
output_emb, has_out_transpose, auto_causal, scale_input);
|
||||
} else {
|
||||
// 1-token version
|
||||
@ -685,7 +723,7 @@ struct ScaledDotProductAttention::AttentionExecutor : public ScaledDotProductAtt
|
||||
// 1, in matrix mutiply, using AMX is not efficency because the M dimension of A will alway be 1
|
||||
// 2, using float will save the repack cost which typically is required for bf16/int8 opt
|
||||
// 3, using dot product can leverage the SIMD while easily adapt to indirect kv cache
|
||||
kernel_single_token(m_query_emb, present_key, present_value, {}, use_attn_mask ? attn_mask : PlainTensor<float>(),
|
||||
kernel_single_token(q_input, present_key, present_value, {}, use_attn_mask ? attn_mask : PlainTensor<float>(),
|
||||
output_emb, beam_table, has_out_transpose, auto_causal, scale_input);
|
||||
}
|
||||
}
|
||||
@ -700,9 +738,10 @@ ScaledDotProductAttention::ScaledDotProductAttention(const std::shared_ptr<ngrap
|
||||
|
||||
const auto node = std::dynamic_pointer_cast<const ov::op::v13::ScaledDotProductAttention>(op);
|
||||
if (node) {
|
||||
m_config.is_causal = node->get_causal();
|
||||
m_config.config.is_causal = node->get_causal();
|
||||
} else {
|
||||
OPENVINO_THROW("CPU: cast to v13::ScaledDotProductAttention failed.");
|
||||
const auto node = std::dynamic_pointer_cast<const ScaledDotProductAttentionStub>(op);
|
||||
m_config.config = node->get_config();
|
||||
}
|
||||
}
|
||||
|
||||
@ -711,6 +750,80 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() {
|
||||
return;
|
||||
auto rtPrecision = getOriginalInputPrecisionAtPort(0);
|
||||
|
||||
NodeConfig config;
|
||||
auto& creatorsMap = BlockedDescCreator::getCommonCreators();
|
||||
auto orginSDPInputNumber = getOriginalInputsNumber() - (m_config.config.fuse_concat ? 2 : 0);
|
||||
config.inConfs.resize(getOriginalInputsNumber());
|
||||
config.outConfs.resize(getOriginalOutputsNumber());
|
||||
config.inConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
|
||||
rtPrecision, getInputShapeAtPort(0)));
|
||||
config.inConfs[1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
|
||||
rtPrecision, getInputShapeAtPort(1)));
|
||||
config.inConfs[2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
|
||||
rtPrecision, getInputShapeAtPort(2)));
|
||||
auto nextPortIdx = 3;
|
||||
if (orginSDPInputNumber > 3) {
|
||||
// attn_mask
|
||||
if (getOriginalInputPrecisionAtPort(nextPortIdx) == ov::element::u8) {
|
||||
config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
|
||||
ov::element::u8, getInputShapeAtPort(nextPortIdx)));
|
||||
} else {
|
||||
config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
|
||||
ov::element::f32, getInputShapeAtPort(nextPortIdx)));
|
||||
}
|
||||
nextPortIdx++;
|
||||
}
|
||||
if (orginSDPInputNumber > 4) {
|
||||
config.inConfs[nextPortIdx].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
|
||||
ov::element::f32, getInputShapeAtPort(nextPortIdx)));
|
||||
}
|
||||
|
||||
if (m_config.config.fuse_concat) {
|
||||
ArbitraryOrderDescCreator cabdDescCreator({2, 0, 1, 3});
|
||||
|
||||
config.inConfs[orginSDPInputNumber + 0].setMemDesc(cabdDescCreator.createSharedDesc(
|
||||
rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 0)));
|
||||
config.inConfs[orginSDPInputNumber + 1].setMemDesc(cabdDescCreator.createSharedDesc(
|
||||
rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 1)));
|
||||
|
||||
config.outConfs[1].setMemDesc(cabdDescCreator.createSharedDesc(
|
||||
rtPrecision, getOutputShapeAtPort(1)));
|
||||
config.outConfs[1].inPlace(orginSDPInputNumber + 0);
|
||||
config.outConfs[2].setMemDesc(cabdDescCreator.createSharedDesc(
|
||||
rtPrecision, getOutputShapeAtPort(2)));
|
||||
config.outConfs[2].inPlace(orginSDPInputNumber + 1);
|
||||
}
|
||||
|
||||
config.outConfs[0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
|
||||
rtPrecision, getOutputShapeAtPort(0)));
|
||||
|
||||
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::ref_any);
|
||||
// may fallback to abcd without inplace
|
||||
if (m_config.config.fuse_concat) {
|
||||
config.inConfs[orginSDPInputNumber + 0].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
|
||||
rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 0)));
|
||||
config.inConfs[orginSDPInputNumber + 1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
|
||||
rtPrecision, getInputShapeAtPort(orginSDPInputNumber + 1)));
|
||||
config.outConfs[1].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
|
||||
rtPrecision, getOutputShapeAtPort(1)));
|
||||
config.outConfs[1].inPlace(-1);
|
||||
config.outConfs[2].setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(
|
||||
rtPrecision, getOutputShapeAtPort(2)));
|
||||
config.outConfs[2].inPlace(-1);
|
||||
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::ref_any);
|
||||
}
|
||||
}
|
||||
|
||||
void ScaledDotProductAttention::createPrimitive() {
|
||||
if (m_config.config.fuse_concat) {
|
||||
auto desc = getSelectedPrimitiveDescriptor();
|
||||
if (desc == nullptr)
|
||||
OPENVINO_THROW("has unidentified preferable primitive descriptor");
|
||||
|
||||
m_config.is_concat_inplaced = desc->getConfig().outConfs[1].inPlace() >= 0;
|
||||
}
|
||||
auto rtPrecision = getOriginalInputPrecisionAtPort(0);
|
||||
|
||||
if (rtPrecision == ov::element::bf16) {
|
||||
m_executor = std::make_shared<AttentionExecutor<KT_ONEDNN, ov::bfloat16>>(m_config);
|
||||
} else {
|
||||
@ -722,29 +835,6 @@ void ScaledDotProductAttention::initSupportedPrimitiveDescriptors() {
|
||||
m_executor = std::make_shared<AttentionExecutor<KT_ONEDNN, float>>(m_config);
|
||||
#endif
|
||||
}
|
||||
|
||||
// initialize input ports
|
||||
std::vector<PortConfigurator> inPortConfigs;
|
||||
inPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getInputShapeAtPort(0), false, -1);
|
||||
inPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getInputShapeAtPort(1), false, -1);
|
||||
inPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getInputShapeAtPort(2), false, -1);
|
||||
if (getOriginalInputsNumber() > 3) {
|
||||
// attn_mask
|
||||
if (getOriginalInputPrecisionAtPort(3) == ov::element::u8) {
|
||||
inPortConfigs.emplace_back(LayoutType::ncsp, ov::element::u8, getInputShapeAtPort(3), false, -1);
|
||||
} else {
|
||||
inPortConfigs.emplace_back(LayoutType::ncsp, ov::element::f32, getInputShapeAtPort(3), false, -1);
|
||||
}
|
||||
}
|
||||
if (getOriginalInputsNumber() > 4) {
|
||||
inPortConfigs.emplace_back(LayoutType::ncsp, ov::element::f32, getInputShapeAtPort(4), false, -1);
|
||||
}
|
||||
|
||||
// initialize output port
|
||||
std::vector<PortConfigurator> outPortConfigs;
|
||||
outPortConfigs.emplace_back(LayoutType::ncsp, rtPrecision, getOutputShapeAtPort(0), false, -1);
|
||||
|
||||
addSupportedPrimDesc(inPortConfigs, outPortConfigs, impl_desc_type::ref_any);
|
||||
}
|
||||
|
||||
void ScaledDotProductAttention::execute(dnnl::stream strm) {
|
||||
@ -760,8 +850,9 @@ void ScaledDotProductAttention::execute(dnnl::stream strm) {
|
||||
|
||||
bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept {
|
||||
try {
|
||||
if (!std::dynamic_pointer_cast<const ov::op::v13::ScaledDotProductAttention>(op)) {
|
||||
errorMessage = "Only ScaledDotProductAttention operation are supported";
|
||||
if (!std::dynamic_pointer_cast<const ov::op::v13::ScaledDotProductAttention>(op) &&
|
||||
!std::dynamic_pointer_cast<const ScaledDotProductAttentionStub>(op)) {
|
||||
errorMessage = "Only ScaledDotProductAttention or ScaledDotProductAttentionStub operation are supported";
|
||||
return false;
|
||||
}
|
||||
// expect shape of q: [B, H, L, S]
|
||||
@ -770,7 +861,14 @@ bool ScaledDotProductAttention::isSupportedOperation(const std::shared_ptr<const
|
||||
errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank);
|
||||
return false;
|
||||
}
|
||||
if (op->get_input_size() > 3) {
|
||||
int orgSDPAInput = static_cast<int>(op->get_input_size());
|
||||
const auto node = std::dynamic_pointer_cast<const ScaledDotProductAttentionStub>(op);
|
||||
if (node) {
|
||||
if (node->get_config().fuse_concat) {
|
||||
orgSDPAInput -= 2;
|
||||
}
|
||||
}
|
||||
if (orgSDPAInput > 3) {
|
||||
inRank = op->get_input_partial_shape(3).size();
|
||||
if (inRank > 4u) {
|
||||
errorMessage = "Doesn't support 'attention mask' with rank: " + std::to_string(inRank);
|
||||
|
@ -10,6 +10,8 @@
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "transformations/cpu_opset/common/op/sdp.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
namespace node {
|
||||
@ -22,6 +24,10 @@ public:
|
||||
bool created() const override {
|
||||
return getType() == Type::ScaledDotProductAttention;
|
||||
}
|
||||
// pastkv may have zero dimension
|
||||
bool isExecutable() const override {
|
||||
return true;
|
||||
}
|
||||
bool needPrepareParams() const override {
|
||||
return false;
|
||||
}
|
||||
@ -30,6 +36,7 @@ public:
|
||||
}
|
||||
void initSupportedPrimitiveDescriptors() override;
|
||||
void execute(dnnl::stream strm) override;
|
||||
void createPrimitive() override;
|
||||
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
|
||||
|
||||
enum KernelTypes { KT_REF, KT_ONEDNN, KT_MLAS};
|
||||
@ -40,9 +47,8 @@ private:
|
||||
};
|
||||
|
||||
struct Config {
|
||||
bool output_BLHxS = false;
|
||||
bool fuse_causal_attn = false;
|
||||
bool is_causal = false;
|
||||
ScaledDotProductAttentionStub::Config config;
|
||||
bool is_concat_inplaced = false;
|
||||
};
|
||||
|
||||
Config m_config;
|
||||
|
@ -51,12 +51,34 @@ void ShapeOf::initSupportedPrimitiveDescriptors() {
|
||||
|
||||
ov::element::Type precision = getOriginalInputPrecisionAtPort(0);
|
||||
|
||||
const LayoutType dataFormats[4] = { LayoutType::ncsp, LayoutType::nspc, LayoutType::nCsp16c, LayoutType::nCsp8c };
|
||||
for (const auto &df : dataFormats) {
|
||||
addSupportedPrimDesc({{df, precision}},
|
||||
{{LayoutType::ncsp, ov::element::i32}},
|
||||
impl_desc_type::ref);
|
||||
}
|
||||
addSupportedPrimDesc({{LayoutType::ncsp, precision}},
|
||||
{{LayoutType::ncsp, ov::element::i32}},
|
||||
impl_desc_type::ref);
|
||||
}
|
||||
|
||||
void ShapeOf::initOptimalPrimitiveDescriptor() {
|
||||
// Mimic the parent node memory desc to avoid extra reorder
|
||||
auto parentEdge = getParentEdgeAt(0);
|
||||
auto parent = parentEdge->getParent();
|
||||
auto parentPd = parent->getSelectedPrimitiveDescriptor();
|
||||
OPENVINO_ASSERT(parentPd,
|
||||
parent->getTypeStr(), " ",
|
||||
parent->getName(),
|
||||
"failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set");
|
||||
|
||||
const auto& parentConfig = parentPd->getConfig();
|
||||
auto mem_desc = parentConfig.outConfs[parentEdge->getInputNum()].getMemDesc();
|
||||
|
||||
auto selected_pd = getSelectedPrimitiveDescriptor();
|
||||
OPENVINO_ASSERT(selected_pd,
|
||||
"ShapeOf ",
|
||||
getName(),
|
||||
" failed getSelectedPrimitiveDescriptor() call, preferable primitive descriptor is not set");
|
||||
|
||||
auto config = selected_pd->getConfig();
|
||||
config.inConfs.front().setMemDesc(mem_desc);
|
||||
//bypass any checks, we enforce the parent descriptor
|
||||
selected_pd->setConfig(config);
|
||||
}
|
||||
|
||||
bool ShapeOf::isExecutable() const {
|
||||
@ -66,12 +88,12 @@ bool ShapeOf::isExecutable() const {
|
||||
void ShapeOf::execute(dnnl::stream strm) {
|
||||
auto inPtr = getParentEdgeAt(0)->getMemoryPtr();
|
||||
auto outPtr = getChildEdgeAt(0)->getMemoryPtr();
|
||||
auto inDims = inPtr->getStaticDims();
|
||||
auto&& inDims = inPtr->getStaticDims();
|
||||
size_t dimsCount = inDims.size();
|
||||
if (outPtr->getStaticDims().size() != 1 || dimsCount != outPtr->getStaticDims()[0])
|
||||
OPENVINO_THROW(errorPrefix, "has inconsistent input shape and output size");
|
||||
|
||||
auto *dst = reinterpret_cast<int *>(getChildEdgeAt(0)->getMemoryPtr()->getData());
|
||||
auto* dst = reinterpret_cast<int *>(outPtr->getData());
|
||||
|
||||
for (size_t i = 0; i < dimsCount; i++) {
|
||||
dst[i] = inDims[i];
|
||||
|
@ -20,6 +20,7 @@ public:
|
||||
|
||||
void getSupportedDescriptors() override;
|
||||
void initSupportedPrimitiveDescriptors() override;
|
||||
void initOptimalPrimitiveDescriptor() override;
|
||||
void execute(dnnl::stream strm) override;
|
||||
bool created() const override;
|
||||
bool needPrepareParams() const override {return false;};
|
||||
|
@ -0,0 +1,57 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "sdp.hpp"
|
||||
|
||||
#include <algorithm>
|
||||
|
||||
#include "transformations/itt.hpp"
|
||||
|
||||
// Constructs the fused SDPA node over the given inputs, stores the fusion
// configuration and immediately runs shape/type inference for the outputs.
ov::intel_cpu::ScaledDotProductAttentionStub::ScaledDotProductAttentionStub(const OutputVector& args, const Config& cfg)
    : Op(args),
      m_config(cfg) {
    constructor_validate_and_infer_types();
}
|
||||
|
||||
// Standard ov::Op clone hook: creates a copy of this node over new inputs,
// carrying over the fusion configuration unchanged.
std::shared_ptr<ov::Node> ov::intel_cpu::ScaledDotProductAttentionStub::clone_with_new_inputs(
    const ov::OutputVector& new_args) const {
    INTERNAL_OP_SCOPE(ScaledDotProductAttentionStub_with_new_inputs);
    check_new_args_count(this, new_args);
    return std::make_shared<ov::intel_cpu::ScaledDotProductAttentionStub>(new_args, m_config);
}
|
||||
|
||||
// Infers output types/shapes for the fused SDPA-with-KV-cache node:
//   output 0   : attention result — same element type and shape as query q;
//   outputs 1/2: updated K/V cache — typed like past_kv (the last input),
//                with the sequence-length dim grown by q's length.
void ov::intel_cpu::ScaledDotProductAttentionStub::validate_and_infer_types() {
    INTERNAL_OP_SCOPE(ScaledDotProductAttentionStub_validate_and_infer_types);
    auto input_num = get_input_size();
    // [B, H, L1, S]
    auto q_ps = get_input_partial_shape(0);
    // [B, H, L0, S]
    auto past_kv_ps = get_input_partial_shape(input_num - 1);

    // BLHxS output layout is not supported by this fused node.
    NODE_VALIDATION_CHECK(this, m_config.output_BLHxS == false);
    NODE_VALIDATION_CHECK(this, q_ps.size() >= 3);
    if (past_kv_ps.rank().is_static()) {
        NODE_VALIDATION_CHECK(this, q_ps.size() == past_kv_ps.size());
        for (size_t i = 0; i < q_ps.size(); i++) {
            // The sequence-length dim (rank-2) may differ between q and the
            // cache; all other dims must be compatible.
            if (i == q_ps.size() - 2)
                continue;
            NODE_VALIDATION_CHECK(this, q_ps[i].compatible(past_kv_ps[i]));
        }
        // New cache length = previous cache length L0 + current query length L1.
        past_kv_ps[q_ps.size() - 2] += q_ps[q_ps.size() - 2];
    }
    set_output_type(0, get_input_element_type(0), q_ps);
    set_output_type(1, get_input_element_type(input_num - 1), past_kv_ps);
    set_output_type(2, get_input_element_type(input_num - 1), past_kv_ps);
}
|
||||
|
||||
// Exposes all Config fields to the attribute visitor, nested under a
// "config" sub-structure; used for node (de)serialization and comparison.
bool ov::intel_cpu::ScaledDotProductAttentionStub::visit_attributes(ov::AttributeVisitor& visitor) {
    INTERNAL_OP_SCOPE(ScaledDotProductAttentionStub_visit_attributes);
    visitor.start_structure("config");
    visitor.on_attribute("output_BLHxS", m_config.output_BLHxS);
    visitor.on_attribute("fuse_causal_attn", m_config.fuse_causal_attn);
    visitor.on_attribute("is_causal", m_config.is_causal);
    visitor.on_attribute("fuse_concat", m_config.fuse_concat);
    visitor.finish_structure();
    return true;
}
|
@ -0,0 +1,50 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "openvino/op/op.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
/// \brief Scaled dot product attention from PyTorch, fused with Concat
|
||||
///
|
||||
/// \ingroup ov_ops_cpp_api
|
||||
|
||||
// CPU-plugin-internal op representing PyTorch-style scaled dot product
// attention optionally fused with the KV-cache Concat (see Config below).
// Produces 3 outputs: attention result plus the updated K and V caches.
class ScaledDotProductAttentionStub : public ov::op::Op {
public:
    OPENVINO_OP("ScaledDotProductAttentionStub", "cpu_plugin_opset");

    ScaledDotProductAttentionStub() = default;

    // Compile-time configuration describing which fusions were applied.
    struct Config {
        bool output_BLHxS = false;  // true implies that output is [B,L,H*S]

        bool fuse_causal_attn = false;  // fuse causal mask and attn mask into attn_mask
        bool is_causal = false;         // apply causal mask internally
        bool fuse_concat = false;       // fuse (concat->sdp) ==> sdp
    };

    ScaledDotProductAttentionStub(const OutputVector& args, const Config& cfg);

    std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& new_args) const override;
    bool visit_attributes(AttributeVisitor& visitor) override;
    void validate_and_infer_types() override;

    // Read-only access to the fusion configuration.
    const Config& get_config() const {
        return m_config;
    }

    // Mutable access to the fusion configuration.
    Config& get_config() {
        return m_config;
    }

private:
    Config m_config;
};
|
||||
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
@ -0,0 +1,121 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "stateful_sdp_fusion.hpp"
|
||||
|
||||
#include <cstdint>
|
||||
#include <limits>
|
||||
#include <openvino/core/rt_info.hpp>
|
||||
#include <openvino/opsets/opset1.hpp>
|
||||
#include <openvino/opsets/opset13.hpp>
|
||||
#include <openvino/opsets/opset6.hpp>
|
||||
#include <openvino/opsets/opset8.hpp>
|
||||
#include <openvino/pass/pattern/op/or.hpp>
|
||||
#include <openvino/pass/pattern/op/wrap_type.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "ov_ops/type_relaxed.hpp"
|
||||
#include "transformations/cpu_opset/common/op/sdp.hpp"
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
|
||||
// Matcher pass that fuses the stateful KV-cache pattern
//   ReadValue[->Convert] -> Concat(past, cur) -> {SDPA, [Convert->]Assign}
// (for both K and V) into a single ScaledDotProductAttentionStub node whose
// extra outputs (1, 2) feed the Assign nodes directly.
StatefulSDPFusion::StatefulSDPFusion() {
    MATCHER_SCOPE(StatefulSDPFusion);
    using namespace ov::pass::pattern;

    // Pattern: past K/V state reads, optionally converted, concatenated with
    // the current K/V, feeding an SDPA with 3, 4 or 5 inputs.
    auto past_k = wrap_type<opset6::ReadValue>();
    auto past_v = wrap_type<opset6::ReadValue>();
    auto convert_past_k = wrap_type<opset1::Convert>({past_k});
    auto convert_past_v = wrap_type<opset1::Convert>({past_v});
    auto concat_input_k = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{past_k, convert_past_k});
    auto concat_input_v = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{past_v, convert_past_v});
    auto concat_k = wrap_type<opset6::Concat>({concat_input_k, any_input()});
    auto concat_v = wrap_type<opset6::Concat>({concat_input_v, any_input()});
    auto sdp0 = wrap_type<opset13::ScaledDotProductAttention>({any_input(), concat_k, concat_v});
    auto sdp1 = wrap_type<opset13::ScaledDotProductAttention>({any_input(), concat_k, concat_v, any_input()});
    auto sdp2 = wrap_type<opset13::ScaledDotProductAttention>({any_input(), concat_k, concat_v, any_input(), any_input()});
    auto sdp = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{sdp0, sdp1, sdp2});

    ov::matcher_pass_callback callback = [=](Matcher& m) {
        const auto& pattern_map = m.get_pattern_value_map();
        auto root = m.get_match_root();

        // Locates the Assign consumer of a Concat output (possibly behind a
        // single-consumer Convert). Requires exactly two consumers of `out`
        // (the SDPA and the Assign path); otherwise leaves `assign` null.
        auto find_assign = [&](const ov::Output<ov::Node>& out, opset6::Assign*& assign, opset1::Convert*& cvt) {
            auto present_to = out.get_target_inputs();
            if (present_to.size() != 2)
                return;
            for (auto& to : present_to) {
                auto to_node = to.get_node();
                if (auto convert = dynamic_cast<opset1::Convert*>(to_node)) {
                    auto cvt_targets = convert->get_output_target_inputs(0);
                    if (cvt_targets.size() == 1) {
                        to_node = cvt_targets.begin()->get_node();
                        cvt = convert;
                    }
                }
                assign = dynamic_cast<opset6::Assign*>(to_node);
                if (assign)
                    return;
            }
        };

        std::shared_ptr<opset1::Convert> read_cvt_k_node, read_cvt_v_node;
        const auto sdp_node = ov::as_type_ptr<opset13::ScaledDotProductAttention>(root);
        const auto past_k_node = ov::as_type_ptr<opset6::ReadValue>(pattern_map.at(past_k).get_node_shared_ptr());
        const auto past_v_node = ov::as_type_ptr<opset6::ReadValue>(pattern_map.at(past_v).get_node_shared_ptr());
        const auto concat_k_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_k).get_node_shared_ptr());
        const auto concat_v_node = ov::as_type_ptr<opset6::Concat>(pattern_map.at(concat_v).get_node_shared_ptr());
        // NOTE(review): only convert_past_k's presence is checked before
        // reading both converts — assumes K and V branches are symmetric;
        // pattern_map.at(convert_past_v) would throw otherwise. Confirm.
        if (pattern_map.count(convert_past_k)) {
            read_cvt_k_node = ov::as_type_ptr<opset1::Convert>(pattern_map.at(convert_past_k).get_node_shared_ptr());
            read_cvt_v_node = ov::as_type_ptr<opset1::Convert>(pattern_map.at(convert_past_v).get_node_shared_ptr());
        }
        opset6::Assign* assign_k_node = nullptr, *assign_v_node = nullptr;
        opset1::Convert* assign_cvt_k_node = nullptr, *assign_cvt_v_node = nullptr;
        // The Assign must write back to the same variable the ReadValue reads,
        // otherwise this is not a genuine KV-cache state loop — bail out.
        find_assign(concat_k_node, assign_k_node, assign_cvt_k_node);
        if (!assign_k_node)
            return false;
        if (past_k_node->get_variable_id() != assign_k_node->get_variable_id())
            return false;

        find_assign(concat_v_node, assign_v_node, assign_cvt_v_node);
        if (!assign_v_node)
            return false;
        if (past_v_node->get_variable_id() != assign_v_node->get_variable_id())
            return false;

        // Rewire: inputs 1/2 become the *current* K/V (Concat's second input),
        // and the past K/V reads are appended as two extra inputs.
        auto args = sdp_node->input_values();
        args[1] = concat_k_node->input_value(1);
        args[2] = concat_v_node->input_value(1);
        args.push_back(read_cvt_k_node ? read_cvt_k_node->output(0) : past_k_node->output(0));
        args.push_back(read_cvt_v_node ? read_cvt_v_node->output(0) : past_v_node->output(0));
        ov::intel_cpu::ScaledDotProductAttentionStub::Config config;

        config.is_causal = sdp_node->get_causal();
        config.fuse_concat = true;

        auto old_node = sdp_node;
        auto new_node = std::make_shared<ov::intel_cpu::ScaledDotProductAttentionStub>(args, config);
        new_node->set_friendly_name(old_node->get_friendly_name());
        // Only output 0 replaces the SDPA result; outputs 1/2 (updated caches)
        // are wired straight into the Assign path below.
        ov::replace_node(old_node, {new_node->output(0)});
        if (assign_cvt_k_node)
            assign_cvt_k_node->set_arguments({new_node->output(1)});
        else
            assign_k_node->set_arguments({new_node->output(1)});

        if (assign_cvt_v_node)
            assign_cvt_v_node->set_arguments({new_node->output(2)});
        else
            assign_v_node->set_arguments({new_node->output(2)});

        return true;
    };

    auto m = std::make_shared<ov::pass::pattern::Matcher>(sdp, matcher_name);
    this->register_matcher(m, callback);
}
|
||||
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
@ -0,0 +1,18 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
namespace ov {
|
||||
namespace intel_cpu {
|
||||
// Matcher pass fusing the stateful KV-cache subgraph
// (ReadValue -> Concat -> ScaledDotProductAttention / Assign) into a single
// CPU-plugin ScaledDotProductAttentionStub node.
class StatefulSDPFusion : public ov::pass::MatcherPass {
public:
    OPENVINO_RTTI("StatefulSDPFusion", "0");
    StatefulSDPFusion();
};
|
||||
|
||||
} // namespace intel_cpu
|
||||
} // namespace ov
|
@ -112,6 +112,7 @@
|
||||
#include "transformations/cpu_opset/common/pass/move_eltwise_up_data_movement.hpp"
|
||||
#include "transformations/cpu_opset/common/pass/swap_convert_transpose.hpp"
|
||||
#include "transformations/cpu_opset/common/pass/rope_fusion.hpp"
|
||||
#include "transformations/cpu_opset/common/pass/stateful_sdp_fusion.hpp"
|
||||
|
||||
// Snippets
|
||||
#include "snippets/pass/tokenization.hpp"
|
||||
@ -660,6 +661,7 @@ void Transformations::PostLpt() {
|
||||
CPU_REGISTER_PASS_X64(postLPTPassManager, EliminateStridedSlice);
|
||||
CPU_REGISTER_PASS_X64(postLPTPassManager, RoPEFusion);
|
||||
|
||||
CPU_REGISTER_PASS_X64(postLPTPassManager, StatefulSDPFusion);
|
||||
postLPTPassManager.run_passes(model);
|
||||
}
|
||||
|
||||
|
@ -169,10 +169,19 @@ struct PlainTensor {
|
||||
}
|
||||
|
||||
void reset(MemoryPtr mem) {
|
||||
assert_dt<DT>(mem->getDesc().getPrecision());
|
||||
const auto& mem_desc = mem->getDesc();
|
||||
assert_dt<DT>(mem_desc.getPrecision());
|
||||
const auto* desc_ptr = mem_desc.as<BlockedMemoryDesc>();
|
||||
// not support block layout
|
||||
OPENVINO_ASSERT(desc_ptr && desc_ptr->getOrder().size() == mem->getStaticDims().size());
|
||||
m_mem = mem;
|
||||
VectorDims strides(desc_ptr->getStrides().size());
|
||||
const auto& orders = desc_ptr->getOrder();
|
||||
for (size_t i = 0; i < orders.size(); i++) {
|
||||
strides[orders[i]] = desc_ptr->getStrides()[i];
|
||||
}
|
||||
// this reshape_to() can do reshape w/o additional cost
|
||||
resize(mem->getStaticDims(), reinterpret_cast<DT*>(mem->getData()));
|
||||
resize(mem->getStaticDims(), reinterpret_cast<DT*>(mem->getData()), &strides);
|
||||
}
|
||||
|
||||
ov::element::Type get_precision(void) {
|
||||
@ -327,14 +336,14 @@ struct PlainTensor {
|
||||
return new_tensor_view;
|
||||
}
|
||||
|
||||
void resize(const VectorDims& new_dims, DT* data = nullptr) {
|
||||
void resize(const VectorDims& new_dims, DT* data = nullptr, const VectorDims* strides = nullptr) {
|
||||
// initialize strides for compact/dense tensor
|
||||
m_rank = new_dims.size();
|
||||
assert(m_rank <= PLAINTENSOR_RANK_MAX);
|
||||
size_t stride = 1;
|
||||
for (int i = m_rank - 1; i >= 0; i--) {
|
||||
m_dims[i] = new_dims[i];
|
||||
m_strides[i] = stride;
|
||||
m_strides[i] = strides ? (*strides)[i] : stride;
|
||||
stride *= new_dims[i];
|
||||
}
|
||||
|
||||
|
@ -72,7 +72,7 @@ protected:
|
||||
for (auto&& shape : inputDynamicShapes)
|
||||
params.push_back(std::make_shared<ov::op::v0::Parameter>(inType, shape));
|
||||
|
||||
auto shapeOf = std::make_shared<ngraph::opset3::ShapeOf>(params[0], ngraph::element::i32);
|
||||
auto shapeOf = std::make_shared<ngraph::opset3::ShapeOf>(params.front(), ngraph::element::i32);
|
||||
|
||||
function = makeNgraphFunction(netPrecision, params, shapeOf, "ShapeOf");
|
||||
}
|
||||
@ -85,29 +85,6 @@ TEST_P(ShapeOfLayerCPUTest, CompareWithRefs) {
|
||||
|
||||
namespace {
|
||||
|
||||
/* CPU PARAMS */
|
||||
std::vector<CPUSpecificParams> getCpuInfoForDimsCount(const size_t dimsCount = 3) {
|
||||
std::vector<CPUSpecificParams> resCPUParams;
|
||||
if (dimsCount == 5) {
|
||||
resCPUParams.push_back(CPUSpecificParams{{nCdhw16c}, {x}, {}, {}});
|
||||
resCPUParams.push_back(CPUSpecificParams{{nCdhw8c}, {x}, {}, {}});
|
||||
resCPUParams.push_back(CPUSpecificParams{{ncdhw}, {x}, {}, {}});
|
||||
resCPUParams.push_back(CPUSpecificParams{{ndhwc}, {x}, {}, {}});
|
||||
} else if (dimsCount == 4) {
|
||||
resCPUParams.push_back(CPUSpecificParams{{nChw16c}, {x}, {}, {}});
|
||||
resCPUParams.push_back(CPUSpecificParams{{nChw8c}, {x}, {}, {}});
|
||||
resCPUParams.push_back(CPUSpecificParams{{nchw}, {x}, {}, {}});
|
||||
resCPUParams.push_back(CPUSpecificParams{{nhwc}, {x}, {}, {}});
|
||||
} else {
|
||||
resCPUParams.push_back(CPUSpecificParams{{nCw16c}, {x}, {}, {}});
|
||||
resCPUParams.push_back(CPUSpecificParams{{nCw8c}, {x}, {}, {}});
|
||||
resCPUParams.push_back(CPUSpecificParams{{abc}, {x}, {}, {}});
|
||||
resCPUParams.push_back(CPUSpecificParams{{acb}, {x}, {}, {}});
|
||||
}
|
||||
|
||||
return resCPUParams;
|
||||
}
|
||||
|
||||
const std::vector<ElementType> netPrecisions = {
|
||||
ElementType::f32,
|
||||
ElementType::bf16,
|
||||
@ -119,17 +96,9 @@ std::vector<ov::test::InputShape> inShapesDynamic3d = {
|
||||
{
|
||||
{-1, -1, -1},
|
||||
{
|
||||
{ 8, 5, 4 },
|
||||
{ 8, 5, 3 },
|
||||
{ 8, 5, 2 }
|
||||
}
|
||||
},
|
||||
{
|
||||
{-1, -1, -1},
|
||||
{
|
||||
{ 1, 2, 4 },
|
||||
{ 1, 2, 3 },
|
||||
{ 1, 2, 2 }
|
||||
{ 8, 16, 4 },
|
||||
{ 8, 16, 3 },
|
||||
{ 8, 16, 2 }
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -138,36 +107,20 @@ std::vector<ov::test::InputShape> inShapesDynamic4d = {
|
||||
{
|
||||
{-1, -1, -1, -1},
|
||||
{
|
||||
{ 8, 5, 3, 4 },
|
||||
{ 8, 5, 3, 3 },
|
||||
{ 8, 5, 3, 2 }
|
||||
{ 8, 16, 3, 4 },
|
||||
{ 8, 16, 3, 3 },
|
||||
{ 8, 16, 3, 2 }
|
||||
}
|
||||
},
|
||||
{
|
||||
{-1, -1, -1, -1},
|
||||
{
|
||||
{ 1, 2, 3, 4 },
|
||||
{ 1, 2, 3, 3 },
|
||||
{ 1, 2, 3, 2 }
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
std::vector<ov::test::InputShape> inShapesDynamic5d = {
|
||||
{
|
||||
{ -1, -1, -1, -1, -1 },
|
||||
{
|
||||
{ 8, 5, 3, 2, 4 },
|
||||
{ 8, 5, 3, 2, 3 },
|
||||
{ 8, 5, 3, 2, 2 }
|
||||
}
|
||||
},
|
||||
{
|
||||
{-1, -1, -1, -1, -1},
|
||||
{
|
||||
{ 1, 2, 3, 4, 4 },
|
||||
{ 1, 2, 3, 4, 3 },
|
||||
{ 1, 2, 3, 4, 2 }
|
||||
{ 8, 16, 3, 2, 4 },
|
||||
{ 8, 16, 3, 2, 3 },
|
||||
{ 8, 16, 3, 2, 2 }
|
||||
}
|
||||
}
|
||||
};
|
||||
@ -175,19 +128,19 @@ const auto params5dDynamic = ::testing::Combine(
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inShapesDynamic5d),
|
||||
::testing::ValuesIn(netPrecisions)),
|
||||
::testing::ValuesIn(getCpuInfoForDimsCount(5)));
|
||||
::testing::Values(emptyCPUSpec));
|
||||
|
||||
const auto params4dDynamic = ::testing::Combine(
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inShapesDynamic4d),
|
||||
::testing::ValuesIn(netPrecisions)),
|
||||
::testing::ValuesIn(getCpuInfoForDimsCount(4)));
|
||||
::testing::Values(emptyCPUSpec));
|
||||
|
||||
const auto params3dDynamic = ::testing::Combine(
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(inShapesDynamic3d),
|
||||
::testing::ValuesIn(netPrecisions)),
|
||||
::testing::ValuesIn(getCpuInfoForDimsCount(3)));
|
||||
::testing::Values(emptyCPUSpec));
|
||||
|
||||
// We don't check static case, because of constant folding
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_ShapeOf3dDynamicLayoutTest, ShapeOfLayerCPUTest,
|
||||
|
@ -0,0 +1,228 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <openvino/opsets/opset13.hpp>
|
||||
#include <transformations/op_conversions/scaled_dot_product_attention_decomposition.hpp>
|
||||
|
||||
#include "ov_models/builders.hpp"
|
||||
#include "ov_models/utils/ov_helpers.hpp"
|
||||
#include "shared_test_classes/base/layer_test_utils.hpp"
|
||||
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||
#include "test_utils/cpu_test_utils.hpp"
|
||||
#include "common_test_utils/include/common_test_utils/ov_tensor_utils.hpp"
|
||||
|
||||
using namespace ov::test;
|
||||
using namespace ngraph;
|
||||
using namespace CPUTestUtils;
|
||||
using namespace InferenceEngine;
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
|
||||
using ConcatSDPTestParams = std::tuple<ElementType,
|
||||
std::vector<InputShape>,
|
||||
bool // has ShapeOf
|
||||
>;
|
||||
// Subgraph:
|
||||
/* Parameter
|
||||
* |
|
||||
* Parameter ReadValue | ReadValue Parameter
|
||||
* \ / | \ /
|
||||
* \ / | \ /
|
||||
* Concat | Concat
|
||||
* / \ | / \
|
||||
* / \ | / \
|
||||
* / \ | / \
|
||||
* Assign ScaledDotProductAttention Assign
|
||||
* |
|
||||
* Add
|
||||
* |
|
||||
* Result
|
||||
*/
|
||||
|
||||
class ConcatSDPTest : public testing::WithParamInterface<ConcatSDPTestParams>, virtual public ov::test::SubgraphBaseTest, public CPUTestsBase {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<ConcatSDPTestParams>& obj) {
|
||||
ElementType inType;
|
||||
std::vector<InputShape> inputShapes;
|
||||
bool hasShapeof;
|
||||
std::tie(inType, inputShapes, hasShapeof) = obj.param;
|
||||
std::ostringstream result;
|
||||
result << "IS=";
|
||||
for (const auto& shape : inputShapes) {
|
||||
result << ov::test::utils::partialShape2str({shape.first}) << "_";
|
||||
}
|
||||
result << "TS=";
|
||||
for (const auto& shape : inputShapes) {
|
||||
result << "(";
|
||||
if (!shape.second.empty()) {
|
||||
for (const auto& itr : shape.second) {
|
||||
result << ov::test::utils::vec2str(itr);
|
||||
}
|
||||
}
|
||||
result << ")_";
|
||||
}
|
||||
result << "Prc=" << inType << "_";
|
||||
result << "HasShapeOf=" << hasShapeof;
|
||||
return result.str();
|
||||
}
|
||||
|
||||
void SetUp() override {
|
||||
ElementType inType;
|
||||
std::vector<InputShape> inputShapes;
|
||||
bool hasShapeOf;
|
||||
std::tie(inType, inputShapes, hasShapeOf) = this->GetParam();
|
||||
targetDevice = ov::test::utils::DEVICE_CPU;
|
||||
rel_threshold = 1e-4f;
|
||||
if (inType == ElementType::bf16) {
|
||||
configuration.insert({"ENFORCE_BF16", "YES"});
|
||||
rel_threshold = 0.01f;
|
||||
}
|
||||
init_input_shapes(inputShapes);
|
||||
ov::ParameterVector inputParams;
|
||||
// q,k,v
|
||||
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
|
||||
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
|
||||
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[0]));
|
||||
inputParams[0]->set_friendly_name("q");
|
||||
inputParams[1]->set_friendly_name("k");
|
||||
inputParams[2]->set_friendly_name("v");
|
||||
// pastkv init_cost
|
||||
inputParams.push_back(std::make_shared<ov::op::v0::Parameter>(inType, inputDynamicShapes[1]));
|
||||
auto var_k = std::make_shared<ov::op::util::Variable>(
|
||||
ov::op::util::VariableInfo{inputDynamicShapes[1], inType, "pastk"});
|
||||
auto pastk = std::make_shared<ov::op::v6::ReadValue>(inputParams[3], var_k);
|
||||
pastk->set_friendly_name("pastk_r");
|
||||
auto var_v = std::make_shared<ov::op::util::Variable>(
|
||||
ov::op::util::VariableInfo{inputDynamicShapes[1], inType, "pastv"});
|
||||
auto pastv = std::make_shared<ov::op::v6::ReadValue>(inputParams[3], var_v);
|
||||
pastv->set_friendly_name("pastv_r");
|
||||
std::shared_ptr<Node> pastk_shapeof, pastv_shapeof;
|
||||
if (hasShapeOf) {
|
||||
pastk_shapeof = std::make_shared<ov::op::v0::ShapeOf>(pastk);
|
||||
pastv_shapeof = std::make_shared<ov::op::v0::ShapeOf>(pastv);
|
||||
}
|
||||
auto concatK = std::make_shared<ov::op::v0::Concat>(OutputVector{pastk, inputParams[1]}, 2);
|
||||
auto concatV = std::make_shared<ov::op::v0::Concat>(OutputVector{pastv, inputParams[2]}, 2);
|
||||
auto sdp = std::make_shared<ov::opset13::ScaledDotProductAttention>(inputParams[0], concatK, concatV, false);
|
||||
sdp->set_friendly_name("mha");
|
||||
auto add = std::make_shared<op::v1::Add>(sdp, op::v0::Constant::create(inType, {1}, {1.0f}));
|
||||
auto pastk_assign = std::make_shared<op::v6::Assign>(concatK, var_k);
|
||||
auto pastv_assign = std::make_shared<op::v6::Assign>(concatV, var_v);
|
||||
pastk_assign->set_friendly_name("pastk_w");
|
||||
pastv_assign->set_friendly_name("pastv_w");
|
||||
|
||||
ResultVector results{std::make_shared<ov::op::v0::Result>(add)};
|
||||
if (hasShapeOf) {
|
||||
results.push_back(std::make_shared<ov::op::v0::Result>(pastk_shapeof));
|
||||
results.push_back(std::make_shared<ov::op::v0::Result>(pastv_shapeof));
|
||||
}
|
||||
SinkVector sinks{pastk_assign, pastv_assign};
|
||||
function = std::make_shared<Function>(results, sinks, inputParams, "ConcatSDP");
|
||||
targetDevice = ov::test::utils::DEVICE_CPU;
|
||||
|
||||
functionRefs = function->clone();
|
||||
pass::Manager manager;
|
||||
// decompose ScaledDotProductAttention
|
||||
manager.register_pass<ov::pass::ScaledDotProductAttentionDecomposition>();
|
||||
manager.run_passes(functionRefs);
|
||||
}
|
||||
void generate_inputs(const std::vector<ov::Shape>& targetInputStaticShapes) override {
|
||||
std::vector<ov::Shape> shapes(4);
|
||||
shapes[0] = targetInputStaticShapes[0];
|
||||
shapes[1] = targetInputStaticShapes[0];
|
||||
shapes[2] = targetInputStaticShapes[0];
|
||||
shapes[3] = targetInputStaticShapes[1];
|
||||
SubgraphBaseTest::generate_inputs(shapes);
|
||||
}
|
||||
template<typename IT, typename T>
|
||||
void strided_iota(IT first, size_t n, T value, T stride) {
|
||||
for (size_t i = 0; i < n; i++) {
|
||||
*first++ = value;
|
||||
value += stride;
|
||||
}
|
||||
}
|
||||
// Populates the `inputs` map with deterministic tensors for one inference
// iteration. Every tensor is an arithmetic progression with stride 0.1 whose
// start value is offset by `idx`, so consecutive iterations see fresh data.
void generate(int idx, const std::vector<ov::Shape>& targetInputStaticShapes) {
    inputs.clear();
    // Creates a tensor matching the parameter's precision (f32, otherwise
    // bf16 is assumed) and registers it against that parameter.
    auto create_input = [this] (std::shared_ptr<op::v0::Parameter> param, ov::Shape shape, float val) {
        if (param->get_element_type() == element::f32) {
            ov::Tensor t{ov::element::f32, shape};
            strided_iota(static_cast<float*>(t.data()), t.get_size(), val, 0.1f);
            inputs.insert({param, t});
        } else {
            ov::Tensor t{ov::element::bf16, shape};
            strided_iota(static_cast<ov::bfloat16*>(t.data()), t.get_size(), val, 0.1f);
            inputs.insert({param, t});
        }
    };
    // q, k, v share the first shape; the fourth parameter uses the second
    // shape. Distinct start offsets (+1..+4) keep the inputs distinguishable.
    create_input(function->get_parameters()[0], targetInputStaticShapes[0], idx + 1.0f);
    create_input(function->get_parameters()[1], targetInputStaticShapes[0], idx + 2.0f);
    create_input(function->get_parameters()[2], targetInputStaticShapes[0], idx + 3.0f);
    create_input(function->get_parameters()[3], targetInputStaticShapes[1], idx + 4.0f);
}
|
||||
// Compiles the current `function` and creates the infer request used by
// run_test(); fails the test immediately if the request could not be created.
void prepare() {
    compile_model();
    inferRequest = compiledModel.create_infer_request();
    ASSERT_TRUE(inferRequest);
}
|
||||
// Clears all variable states (the past-KV caches) and drops the infer
// request so the next run_test() call starts from a clean slate.
void reset() {
    for (auto&& state : inferRequest.query_state()) {
        state.reset();
    }
    // Replace with a default-constructed (empty) request to release resources.
    inferRequest = ov::InferRequest();
}
|
||||
// Compiles `model`, runs one inference per entry of targetStaticShapes
// (feeding freshly generated inputs each iteration so the KV-cache state
// accumulates across iterations), and returns a deep copy of output 0 from
// every iteration. State is reset before returning.
std::vector<ov::Tensor> run_test(std::shared_ptr<ov::Model> model) {
    function = model;
    prepare();
    std::vector<ov::Tensor> outputs;
    int idx = 0;
    for (auto&& shapes : targetStaticShapes) {
        generate(idx++, shapes);
        for (const auto& input : inputs) {
            inferRequest.set_tensor(input.first, input.second);
        }
        inferRequest.infer();
        // Copy the output: the request-owned tensor may be overwritten by the
        // next infer() call.
        auto outputTensor = inferRequest.get_output_tensor(0);
        ov::Tensor copy{outputTensor.get_element_type(), outputTensor.get_shape()};
        outputTensor.copy_to(copy);
        outputs.push_back(copy);
    }
    reset();

    return outputs;
}
|
||||
};
|
||||
|
||||
// Runs the fused model, checks the expected node composition of the compiled
// graph (SDPA fused, no Concat/Reorder nodes), then runs the decomposed
// reference model and compares outputs element-wise within the thresholds.
TEST_P(ConcatSDPTest, CompareWithRefs) {
    auto actualOutputs = run_test(function);
    CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 1);
    CheckNumberOfNodesWithType(compiledModel, "Concatenation", 0);
    CheckNumberOfNodesWithType(compiledModel, "Reorder", 0);
    // The reference model was decomposed in SetUp, so no SDPA node remains.
    auto expectedOutputs = run_test(functionRefs);
    CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 0);
    for (size_t i = 0; i < actualOutputs.size(); i++) {
        ov::test::utils::compare(expectedOutputs[i], actualOutputs[i], abs_threshold, rel_threshold);
    }
}
|
||||
|
||||
namespace {
// Test shapes: dim 2 (sequence length) is dynamic; batch is fixed at 1.
const std::vector<std::vector<InputShape>> inputShapes = {
    {
        // B, H, L1, S — new-token (q/k/v) input per iteration
        {{1, 8, -1, 64}, {{1, 8, 10, 64}, {1, 8, 1, 64}, {1, 8, 1, 64}, {1, 8, 20, 64}, {1, 8, 1, 64}}},
        // B, H, L0, S — past-KV input per iteration (starts empty, L0 = 0)
        {{1, 8, -1, 64}, {{1, 8, 0, 64}, {1, 8, 10, 64}, {1, 8, 11, 64}, {1, 8, 12, 64}, {1, 8, 32, 64}}},
    },
};

// f32 precision, both with and without the trailing ShapeOf branch.
INSTANTIATE_TEST_SUITE_P(smoke_ConcatSDPTest,
                         ConcatSDPTest,
                         ::testing::Combine(::testing::Values(ElementType::f32),
                                            ::testing::ValuesIn(inputShapes),
                                            ::testing::Values(true, false)),
                         ConcatSDPTest::getTestCaseName);

} // namespace
|
||||
} // namespace SubgraphTestsDefinitions
|
@ -0,0 +1,201 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "test_utils/cpu_test_utils.hpp"
|
||||
|
||||
#include "ov_models/builders.hpp"
|
||||
#include "ov_models/utils/ov_helpers.hpp"
|
||||
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||
|
||||
using namespace InferenceEngine;
|
||||
using namespace CPUTestUtils;
|
||||
|
||||
using InputShape = ov::test::InputShape;
|
||||
using ElementType = ov::element::Type_t;
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
|
||||
// ┌────────┐
|
||||
// │ Param │
|
||||
// └───┬────┘
|
||||
// │
|
||||
// │
|
||||
// │
|
||||
// ┌───┴────┐ To simulate different layouts
|
||||
// │ Eltwise│ ◄─────────────────────────────
|
||||
// └───┬────┘
|
||||
// │ No Reorders are expected
|
||||
// │ ◄───────────────────────────
|
||||
// │
|
||||
// ┌───┴────┐
|
||||
// │ShapeOf │
|
||||
// └───┬────┘
|
||||
// │
|
||||
// │
|
||||
// │
|
||||
// ┌───┴────┐
|
||||
// │ Output │
|
||||
// └────────┘
|
||||
|
||||
typedef std::tuple<
|
||||
InputShape,
|
||||
ElementType // Net precision
|
||||
> ShapeOfAnyLayoutParams;
|
||||
|
||||
typedef std::tuple<
|
||||
ShapeOfAnyLayoutParams,
|
||||
CPUSpecificParams
|
||||
> ShapeOfAnyLayoutCPUTestParamsSet;
|
||||
|
||||
// Verifies that a ShapeOf node accepts any input layout: an eltwise node with
// an enforced blocked/permuted layout feeds ShapeOf, and no Reorder should be
// inserted between them (checked in the TEST_P body below).
class ShapeOfAnyLayoutCPUTest : public testing::WithParamInterface<ShapeOfAnyLayoutCPUTestParamsSet>,
                                virtual public ov::test::SubgraphBaseTest, public CPUTestsBase {
public:
    // Builds a readable, unique test name from the parameter tuple:
    // index, precision, CPU format config, partial shape and target shapes.
    static std::string getTestCaseName(testing::TestParamInfo<ShapeOfAnyLayoutCPUTestParamsSet> obj) {
        SubgraphTestsDefinitions::ShapeOfAnyLayoutParams basicParamsSet;
        CPUSpecificParams cpuParams;
        std::tie(basicParamsSet, cpuParams) = obj.param;
        ElementType netPr;
        InputShape inputShape;

        std::tie(inputShape, netPr) = basicParamsSet;
        std::ostringstream result;
        result << "ShapeOfTest_";
        result << std::to_string(obj.index) << "_";
        result << "Prec=" << netPr << "_";
        result << CPUTestsBase::getTestCaseName(cpuParams) << "_";
        result << "IS=";
        result << ov::test::utils::partialShape2str({inputShape.first}) << "_";
        result << "TS=(";
        for (const auto& shape : inputShape.second) {
            result << ov::test::utils::vec2str(shape) << "_";
        }
        result << ")";
        return result.str();
    }
protected:
    // Builds: Parameter -> Eltwise(Relu, forced layout) -> ShapeOf -> Result.
    void SetUp() override {
        targetDevice = ov::test::utils::DEVICE_CPU;

        ShapeOfAnyLayoutParams basicParamsSet;
        CPUSpecificParams cpuParams;
        std::tie(basicParamsSet, cpuParams) = this->GetParam();
        // Unpack the in/out memory formats to force on the eltwise node;
        // priority and selectedType are test-base members.
        std::vector<cpu_memory_format_t> eltwiseInFmts, eltwiseOutFmts;
        std::tie(eltwiseInFmts, eltwiseOutFmts, priority, selectedType) = cpuParams;

        auto netPrecision = ElementType::undefined;
        InputShape inputShape;
        std::tie(inputShape, netPrecision) = basicParamsSet;
        init_input_shapes({inputShape});

        inType = ov::element::Type(netPrecision);
        // ShapeOf is created with i32 output below.
        outType = ElementType::i32;
        selectedType = makeSelectedTypeStr("ref", inType);

        ov::ParameterVector params;
        for (auto&& shape : inputDynamicShapes)
            params.push_back(std::make_shared<ov::op::v0::Parameter>(inType, shape));

        //make a stub eltwise node to enforce layout, since ShapeOf just mimic any input layout
        auto eltwise = ngraph::builder::makeActivation(params[0], inType, ov::test::utils::ActivationTypes::Relu);
        eltwise->get_rt_info() = makeCPUInfo(eltwiseInFmts, eltwiseOutFmts, {});

        auto shapeOf = std::make_shared<ngraph::opset3::ShapeOf>(eltwise, ngraph::element::i32);

        function = makeNgraphFunction(netPrecision, params, shapeOf, "ShapeOf");
    }
};
|
||||
|
||||
// Expects exactly one Reorder in the whole graph (on the eltwise path), i.e.
// none inserted between the layout-enforced eltwise and ShapeOf.
TEST_P(ShapeOfAnyLayoutCPUTest, CompareWithRefs) {
    run();
    CheckPluginRelatedResults(compiledModel, "ShapeOf");
    CheckNumberOfNodesWithType(compiledModel, "Reorder", 1);
}
|
||||
|
||||
namespace {
|
||||
|
||||
/* CPU PARAMS */
|
||||
std::vector<CPUSpecificParams> getCpuInfoForDimsCount(const size_t dimsCount = 3) {
|
||||
std::vector<CPUSpecificParams> resCPUParams;
|
||||
const bool avx512_target = with_cpu_x86_avx512f();
|
||||
|
||||
if (dimsCount == 5) {
|
||||
auto blocked_format = avx512_target ? nCdhw16c : nCdhw8c;
|
||||
resCPUParams.push_back(CPUSpecificParams{{blocked_format}, {blocked_format}, {}, {}});
|
||||
resCPUParams.push_back(CPUSpecificParams{{ndhwc}, {ndhwc}, {}, {}});
|
||||
} else if (dimsCount == 4) {
|
||||
auto blocked_format = avx512_target ? nChw16c : nChw8c;
|
||||
resCPUParams.push_back(CPUSpecificParams{{blocked_format}, {blocked_format}, {}, {}});
|
||||
resCPUParams.push_back(CPUSpecificParams{{nhwc}, {nhwc}, {}, {}});
|
||||
} else {
|
||||
auto blocked_format = avx512_target ? nCw16c : nCw8c;
|
||||
resCPUParams.push_back(CPUSpecificParams{{blocked_format}, {blocked_format}, {}, {}});
|
||||
resCPUParams.push_back(CPUSpecificParams{{acb}, {acb}, {}, {}});
|
||||
}
|
||||
|
||||
return resCPUParams;
|
||||
}
|
||||
|
||||
// Only f32 is exercised; layouts are the variable under test, not precision.
const std::vector<ElementType> netPrecisions = {
    ElementType::f32
};

// Dynamic shapes: channel dim fixed at 16 (so blocked layouts apply),
// remaining dims dynamic with three concrete target shapes each.
std::vector<ov::test::InputShape> inShapesDynamic3d = {
    {
        {-1, 16, -1},
        {
            { 8, 16, 4 },
            { 8, 16, 3 },
            { 8, 16, 2 }
        }
    }
};

std::vector<ov::test::InputShape> inShapesDynamic4d = {
    {
        {-1, 16, -1, -1},
        {
            { 8, 16, 3, 4 },
            { 8, 16, 3, 3 },
            { 8, 16, 3, 2 }
        }
    },
};

std::vector<ov::test::InputShape> inShapesDynamic5d = {
    {
        { -1, 16, -1, -1, -1 },
        {
            { 8, 16, 3, 2, 4 },
            { 8, 16, 3, 2, 3 },
            { 8, 16, 3, 2, 2 }
        }
    }
};
|
||||
// Each rank combines its dynamic shapes with the precision list and the
// rank-matched CPU layout configs.
const auto params5dDynamic = ::testing::Combine(
        ::testing::Combine(
                ::testing::ValuesIn(inShapesDynamic5d),
                ::testing::ValuesIn(netPrecisions)),
        ::testing::ValuesIn(getCpuInfoForDimsCount(5)));

const auto params4dDynamic = ::testing::Combine(
        ::testing::Combine(
                ::testing::ValuesIn(inShapesDynamic4d),
                ::testing::ValuesIn(netPrecisions)),
        ::testing::ValuesIn(getCpuInfoForDimsCount(4)));

const auto params3dDynamic = ::testing::Combine(
        ::testing::Combine(
                ::testing::ValuesIn(inShapesDynamic3d),
                ::testing::ValuesIn(netPrecisions)),
        ::testing::ValuesIn(getCpuInfoForDimsCount(3)));

// We don't check static case, because of constant folding
INSTANTIATE_TEST_SUITE_P(smoke_ShapeOf3dAnyLayoutTest, ShapeOfAnyLayoutCPUTest,
                         params3dDynamic, ShapeOfAnyLayoutCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_ShapeOf4dAnyLayoutTest, ShapeOfAnyLayoutCPUTest,
                         params4dDynamic, ShapeOfAnyLayoutCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_ShapeOf5dAnyLayoutTest, ShapeOfAnyLayoutCPUTest,
                         params5dDynamic, ShapeOfAnyLayoutCPUTest::getTestCaseName);
|
||||
} // namespace
|
||||
} // namespace SubgraphTestsDefinitions
|
161
src/plugins/intel_cpu/tests/unit/graph/scaled_attn.cpp
Normal file
161
src/plugins/intel_cpu/tests/unit/graph/scaled_attn.cpp
Normal file
@ -0,0 +1,161 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <common/blocked_desc_creator.h>
|
||||
#include <cpu_types.h>
|
||||
#include <edge.h>
|
||||
#include <gtest/gtest.h>
|
||||
#include <ie_common.h>
|
||||
#include <memory_desc/cpu_memory_desc_utils.h>
|
||||
#include <memory_desc/dnnl_memory_desc.h>
|
||||
#include <node.h>
|
||||
#include <nodes/reorder.h>
|
||||
|
||||
#include <common/memory_desc_wrapper.hpp>
|
||||
#include <dnnl.hpp>
|
||||
#include <utility>
|
||||
|
||||
#include "common_test_utils/common_utils.hpp"
|
||||
#include "cache/multi_cache.h"
|
||||
#include "ov_models/builders.hpp"
|
||||
#include "nodes/scaled_attn.h"
|
||||
#include "nodes/input.h"
|
||||
#include "graph.h"
|
||||
#include "cpu_tensor.h"
|
||||
|
||||
using namespace ov::intel_cpu;
|
||||
|
||||
// Builds a minimal CPU-plugin graph around a fused SDPA node (with KV-cache
// concat enabled) by hand, runs it once, and checks that when the past-KV
// outputs are NOT placed in-place (rtCacheCapacity = 0 context) the concat
// still happens: past outputs double their sequence dim and contain the
// concatenated past+new values.
TEST(ScaledAttnGraphTest, smoke_Check_Scaled_Concat_Noplace) {
    // Constructs the ov model and the corresponding intel_cpu Graph with
    // explicit nodes and edges (no compilation pipeline involved).
    auto build_graph = [](const ov::Shape& shape, float* qkv_val, float* past_kv_val) {
        auto qkv = ov::op::v0::Constant::create(ov::element::f32, shape, qkv_val);
        qkv->set_friendly_name("qkv_const");
        auto pastkv = ov::op::v0::Constant::create(ov::element::f32, shape, past_kv_val);
        pastkv->set_friendly_name("pastkv_const");
        // only need a dynamic parameter but its value will not be used
        auto attn = std::make_shared<ov::op::v0::Parameter>(ov::element::f32, ov::PartialShape{-1});
        attn->set_friendly_name("attn");

        ov::intel_cpu::ScaledDotProductAttentionStub::Config config;
        config.fuse_concat = true;
        config.is_causal = true;
        // Inputs: q, k, v, attn mask, past k, past v. q/k/v share one constant.
        auto sdpa = std::make_shared<ov::intel_cpu::ScaledDotProductAttentionStub>(ov::OutputVector{qkv, qkv, qkv, attn, pastkv, pastkv}, config);
        auto out_qkv = std::make_shared<ov::op::v0::Result>(sdpa->output(0));
        out_qkv->set_friendly_name("qkv");
        auto out_pastk = std::make_shared<ov::op::v0::Result>(sdpa->output(1));
        out_pastk->set_friendly_name("pastk");
        auto out_pastv = std::make_shared<ov::op::v0::Result>(sdpa->output(2));
        out_pastv->set_friendly_name("pastv");

        std::unordered_set<NodePtr> nodes_set;
        std::vector<EdgePtr> graph_edges;

        // Connects parent:parentPort -> child:childPort and records both
        // endpoints; nodes_set dedupes nodes used on several edges.
        auto add_edge = [&](const NodePtr& parent, const NodePtr& child, size_t parentPort, size_t childPort) -> void {
            auto edge = std::make_shared<Edge>(parent, child, parentPort, childPort);
            child->addEdge(edge);
            graph_edges.push_back(edge);
            nodes_set.insert(parent);
            nodes_set.insert(child);
        };

        //create graph context
        // rtCacheCapacity = 0 disables the runtime cache for this context.
        Config conf;
        conf.rtCacheCapacity = 0;
        auto context = std::make_shared<GraphContext>(conf, nullptr, nullptr, false);

        auto qkv_node = std::make_shared<node::Input>(qkv, context);
        auto pastkv_node = std::make_shared<node::Input>(pastkv, context);
        auto attn_node = std::make_shared<node::Input>(attn, context);
        auto sdpa_node = std::make_shared<node::ScaledDotProductAttention>(sdpa, context);
        auto out_qkv_node = std::make_shared<node::Input>(out_qkv, context);
        auto out_pastk_node = std::make_shared<node::Input>(out_pastk, context);
        auto out_pastv_node = std::make_shared<node::Input>(out_pastv, context);

        // Wire inputs 0..5 of the SDPA node, then its three outputs.
        add_edge(qkv_node, sdpa_node, 0, 0);
        add_edge(qkv_node, sdpa_node, 0, 1);
        add_edge(qkv_node, sdpa_node, 0, 2);
        add_edge(attn_node, sdpa_node, 0, 3);
        add_edge(pastkv_node, sdpa_node, 0, 4);
        add_edge(pastkv_node, sdpa_node, 0, 5);
        add_edge(sdpa_node, out_qkv_node, 0, 0);
        add_edge(sdpa_node, out_pastk_node, 1, 0);
        add_edge(sdpa_node, out_pastv_node, 2, 0);

        std::vector<NodePtr> graph_nodes(nodes_set.begin(), nodes_set.end());

        Graph graph;
        graph.CreateGraph(graph_nodes, graph_edges, context, "test_graph");
        return graph;
    };

    // Resizes the dynamic input and performs one inference pass.
    auto run_graph = [] (Graph& graph) {
        // NOTE(review): begin() of the input-node map is assumed to be the
        // dynamic `attn` parameter; resized to a single element.
        graph.GetInputNodesMap().begin()->second->redefineOutputMemory(0, {1});

        for (auto& node : graph.GetNodes()) {
            if (node->isDynamicNode()) {
                node->updateShapes();
                node->updateDynamicParams();
            }
        }
        graph.Infer();
    };

    // Compares selected graph outputs against expected (data pointer, shape)
    // pairs with an absolute float tolerance of 1e-4.
    auto check_graph = [] (Graph& graph, std::map<std::string, std::pair<float*, ov::Shape>>& expected) {
        auto& outputNodesMap = graph.GetOutputNodesMap();
        auto is_same = [] (float a, float b) {
            return std::abs(a - b) < 0.0001f;
        };
        for (auto &outputMap : outputNodesMap) {
            auto name = outputMap.first;
            if (expected.count(name) == 0) {
                continue;
            }
            auto node = outputMap.second;
            auto parentEdge = node->getParentEdgeAt(0);
            const auto& memory = parentEdge->getMemoryPtr();
            auto size = memory->getSize() / sizeof(float);
            auto p = reinterpret_cast<float*>(memory->getData());
            for (size_t i = 0; i < size; i++) {
                ASSERT_EQ(is_same(p[i], expected.at(name).first[i]), true);
            }
            ASSERT_EQ(memory->getShape(), ov::intel_cpu::Shape(expected.at(name).second));
        }
    };

    // Returns the first node of the given type, or nullptr if absent.
    auto find_node_type = [](const Graph& graph, Type type) -> NodePtr {
        auto&& nodes = graph.GetNodes();
        auto itr =
            std::find_if(nodes.begin(), nodes.end(), [=](const NodePtr& node){ return type == node->getType(); });

        if (itr == nodes.end()) {
            return nullptr;
        }

        return (*itr);
    };

    // Fills n floats with value, value + stride, value + 2 * stride, ...
    auto strided_iota = [] (float* first, size_t n, float value, float stride) {
        for (size_t i = 0; i < n; i++) {
            *first++ = value;
            value += stride;
        }
    };

    // One contiguous ramp covers both tensors: the first half is past-KV,
    // the second half is qkv — so the concatenated [past, new] result equals
    // the whole ramp, which is what `expected` points at below.
    ov::Shape shape{1, 1, 8, 8};
    const size_t elements_count = std::accumulate(shape.begin(), shape.end(), size_t{1}, std::multiplies<size_t>());
    std::vector<float> val(elements_count * 2);
    strided_iota(val.data(), val.size(), -10.0f, 0.1f);
    auto graph = build_graph(shape, val.data() + elements_count, val.data());
    run_graph(graph);
    // if no inplace, the pastk and pastv will concat, check shape and value
    ov::Shape expectedShape(shape);
    expectedShape[2] *= 2;
    std::map<std::string, std::pair<float*, ov::Shape>> expected{
        {"pastk", std::make_pair(val.data(), expectedShape)},
        {"pastv", std::make_pair(val.data(), expectedShape)}};
    check_graph(graph, expected);
    // -1 means the past-KV outputs are not produced in-place.
    auto spd = find_node_type(graph, Type::ScaledDotProductAttention)->getSelectedPrimitiveDescriptor();
    ASSERT_EQ(spd->getConfig().outConfs[1].inPlace(), -1);
    ASSERT_EQ(spd->getConfig().outConfs[2].inPlace(), -1);
}
|
@ -0,0 +1,105 @@
|
||||
// Copyright (C) 2018-2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include <openvino/opsets/opset13.hpp>
|
||||
#include <transformations/cpu_opset/common/op/sdp.hpp>
|
||||
#include <transformations/cpu_opset/common/pass/stateful_sdp_fusion.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <openvino/pass/manager.hpp>
|
||||
#include <ie_core.hpp>
|
||||
|
||||
#include "common_test_utils/ov_test_utils.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ov::intel_cpu;
|
||||
using namespace ov;
|
||||
|
||||
// Builds the stateful SDPA test model:
//   ReadValue(pastk/pastv) -> [Convert] -> Concat(past, new, axis 2)
//   -> SDPA -> Add -> Result, with Assign sinks writing the concat results
//   back to the variables.
// isRef=true instead builds the already-fused form (one op producing the
// attention output plus both concatenated KV outputs) that StatefulSDPFusion
// is expected to produce. hasConvert inserts f32->f32 Converts around the
// state reads/writes to check the fusion tolerates them.
static std::shared_ptr<ov::Model> makeSDPA(const ov::PartialShape& inputShape, bool isRef = false, bool hasConvert = false) {
    auto q = std::make_shared<ov::op::v0::Parameter>(element::f32, inputShape);
    auto k = std::make_shared<ov::op::v0::Parameter>(element::f32, inputShape);
    auto v = std::make_shared<ov::op::v0::Parameter>(element::f32, inputShape);
    // `init` is an extra parameter kept only to give both variants the same
    // parameter list; it feeds nothing.
    auto init = std::make_shared<ov::op::v0::Parameter>(element::f32, inputShape);
    auto var_k = std::make_shared<ov::op::util::Variable>(
        ov::op::util::VariableInfo{inputShape, element::f32, "pastk"});
    std::shared_ptr<ov::Node> pastk = std::make_shared<ov::op::v6::ReadValue>(k, var_k);
    auto var_v = std::make_shared<ov::op::util::Variable>(
        ov::op::util::VariableInfo{inputShape, element::f32, "pastv"});
    std::shared_ptr<ov::Node> pastv = std::make_shared<ov::op::v6::ReadValue>(v, var_v);
    Output<ov::Node> concatK, concatV, sdp;
    if (hasConvert) {
        pastk = std::make_shared<ov::op::v0::Convert>(pastk, element::f32);
        pastv = std::make_shared<ov::op::v0::Convert>(pastv, element::f32);
    }
    if (isRef) {
        // Fused form: one node yields attention output and both KV concats.
        ov::intel_cpu::ScaledDotProductAttentionStub::Config config;
        config.fuse_concat = true;
        auto new_node = std::make_shared<ov::intel_cpu::ScaledDotProductAttentionStub>(OutputVector{q, k, v, pastk, pastv}, config);
        sdp = new_node->output(0);
        concatK = new_node->output(1);
        concatV = new_node->output(2);
    } else {
        // Unfused form: explicit Concat along the sequence axis (2).
        concatK = std::make_shared<ov::op::v0::Concat>(OutputVector{pastk, k}, 2);
        concatV = std::make_shared<ov::op::v0::Concat>(OutputVector{pastv, v}, 2);
        sdp = std::make_shared<ov::opset13::ScaledDotProductAttention>(q, concatK, concatV, false);
    }
    if (hasConvert) {
        concatK = std::make_shared<ov::op::v0::Convert>(concatK, element::f32);
        concatV = std::make_shared<ov::op::v0::Convert>(concatV, element::f32);
    }
    // Write the updated KV cache back into the variables.
    auto pastk_assign = std::make_shared<op::v6::Assign>(concatK, var_k);
    auto pastv_assign = std::make_shared<op::v6::Assign>(concatV, var_v);
    auto add = std::make_shared<op::v1::Add>(sdp, op::v0::Constant::create(element::f32, {1}, {1.0f}));

    ResultVector results{std::make_shared<ov::op::v0::Result>(add)};
    SinkVector sinks{pastk_assign, pastv_assign};
    return std::make_shared<Model>(results, sinks, ParameterVector{q, k, v, init}, "ConcatSDP");
}
|
||||
|
||||
// Verifies that StatefulSDPFusion rewrites the unfused
// ReadValue -> Concat -> SDPA -> Assign pattern into the fused form that
// makeSDPA(..., isRef = true) builds directly.
TEST(TransformationTests, StateConcatSDPA) {
    const auto inputShape = ov::PartialShape{-1, 8, -1, 64};

    // Build the unfused model and run the fusion pass on it.
    std::shared_ptr<ov::Model> model = makeSDPA(inputShape);
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::InitNodeInfo>();
    manager.register_pass<StatefulSDPFusion>();
    manager.run_passes(model);

    // Reference model with the fused op constructed directly.
    const std::shared_ptr<ov::Model> reference = makeSDPA(inputShape, true);

    const auto res = compare_functions(model, reference);
    ASSERT_TRUE(res.first) << res.second;
}
|
||||
|
||||
// Same check as StateConcatSDPA, but with Convert nodes wrapped around the
// state reads/writes (hasConvert = true): the fusion must still trigger.
TEST(TransformationTests, StateConcatSDPAWithConvert) {
    const auto inputShape = ov::PartialShape{-1, 8, -1, 64};

    // Build the unfused model (with Converts) and run the fusion pass.
    std::shared_ptr<ov::Model> model = makeSDPA(inputShape, false, true);
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::InitNodeInfo>();
    manager.register_pass<StatefulSDPFusion>();
    manager.run_passes(model);

    // Reference model: fused op with the same Converts.
    const std::shared_ptr<ov::Model> reference = makeSDPA(inputShape, true, true);

    const auto res = compare_functions(model, reference);
    ASSERT_TRUE(res.first) << res.second;
}
|
Loading…
Reference in New Issue
Block a user