[CPU] Use primitive priority list more efficiently (#17135)

This commit is contained in:
Egor Duplenskii 2023-06-13 15:55:27 +02:00 committed by GitHub
parent d95c49d888
commit e738c4e83f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
49 changed files with 647 additions and 584 deletions

View File

@ -179,21 +179,11 @@ std::string DnnlExtensionUtils::query_impl_info_str(const const_dnnl_primitive_d
return std::string(res); return std::string(res);
} }
bool DnnlExtensionUtils::find_implementation(dnnl::primitive_desc& desc, impl_desc_type implType) { bool DnnlExtensionUtils::find_implementation(dnnl::primitive_desc& desc, impl_desc_type impl_type) {
primitive_desc_iterator& itpd = desc; return DnnlExtensionUtils::find_implementation(desc,
[impl_type](impl_desc_type cur_impl_type){
while (itpd) { return cur_impl_type == impl_type;
const impl_desc_type descImplType = parse_impl_name(itpd.impl_info_str()); });
if (descImplType == implType) {
return true;
}
if (!itpd.next_impl())
break;
}
return false;
} }
dnnl_memory_desc_t DnnlExtensionUtils::clone_desc(const_dnnl_memory_desc_t cdesc) { dnnl_memory_desc_t DnnlExtensionUtils::clone_desc(const_dnnl_memory_desc_t cdesc) {
@ -202,6 +192,12 @@ dnnl_memory_desc_t DnnlExtensionUtils::clone_desc(const_dnnl_memory_desc_t cdesc
return cloned_md; return cloned_md;
} }
dnnl_primitive_desc_t DnnlExtensionUtils::clone_primitive_desc(const_dnnl_primitive_desc_t cprim_desc) {
dnnl_primitive_desc_t cloned_md = nullptr;
dnnl_primitive_desc_clone(&cloned_md, cprim_desc);
return cloned_md;
}
const char* DnnlExtensionUtils::query_pd_info(const_dnnl_primitive_desc_t pd) { const char* DnnlExtensionUtils::query_pd_info(const_dnnl_primitive_desc_t pd) {
return pd->info(); return pd->info();
} }

View File

@ -54,7 +54,47 @@ public:
static std::shared_ptr<DnnlMemoryDesc> query_md(const const_dnnl_primitive_desc_t& pd, const dnnl::query& what, int idx = 0); static std::shared_ptr<DnnlMemoryDesc> query_md(const const_dnnl_primitive_desc_t& pd, const dnnl::query& what, int idx = 0);
static std::string query_impl_info_str(const const_dnnl_primitive_desc_t& pd); static std::string query_impl_info_str(const const_dnnl_primitive_desc_t& pd);
template<typename T>
static bool find_implementation(dnnl::primitive_desc& desc, T&& comparator) {
dnnl::primitive_desc_iterator& itpd = desc;
while (itpd) {
const impl_desc_type descImplType = parse_impl_name(itpd.impl_info_str());
if (comparator(descImplType)) {
return true;
}
if (!itpd.next_impl())
break;
}
return false;
}
template<typename T, typename L>
static void for_each_implementation(dnnl::primitive_desc& desc, bool first_match, T&& comparator, L&& func) {
dnnl::primitive_desc_iterator& itpd = desc;
while (itpd) {
const impl_desc_type descImplType = parse_impl_name(itpd.impl_info_str());
if (comparator(descImplType)) {
func(itpd);
if (first_match)
break;
}
if (!itpd.next_impl())
break;
}
return;
}
static bool find_implementation(dnnl::primitive_desc& desc, impl_desc_type implType); static bool find_implementation(dnnl::primitive_desc& desc, impl_desc_type implType);
static dnnl_primitive_desc_t clone_primitive_desc(const_dnnl_primitive_desc_t cprim_desc);
static dnnl_memory_desc_t clone_desc(const_dnnl_memory_desc_t cdesc); static dnnl_memory_desc_t clone_desc(const_dnnl_memory_desc_t cdesc);
static const char* query_pd_info(const_dnnl_primitive_desc_t pd); static const char* query_pd_info(const_dnnl_primitive_desc_t pd);
static dnnl::algorithm convertToDnnlAlgorithm(Algorithm alg); static dnnl::algorithm convertToDnnlAlgorithm(Algorithm alg);

View File

@ -8,6 +8,12 @@
namespace ov { namespace ov {
namespace intel_cpu { namespace intel_cpu {
/* c++11 requires to have a definition in cpp file */
constexpr BlockedMemoryDesc::CmpMask BlockedMemoryDesc::FULL_MASK;
constexpr BlockedMemoryDesc::CmpMask BlockedMemoryDesc::EMPTY_MASK;
constexpr BlockedMemoryDesc::CmpMask BlockedMemoryDesc::SKIP_OFFSET_MASK;
constexpr size_t BlockedMemoryDesc::OFFSET_MASK_POS;
bool BlockedMemoryDesc::isCompatibleInternal(const BlockedMemoryDesc &rhs, CmpMask cmpMask) const { bool BlockedMemoryDesc::isCompatibleInternal(const BlockedMemoryDesc &rhs, CmpMask cmpMask) const {
if (this->getShape() != rhs.getShape() || this->getPrecision() != rhs.getPrecision()) if (this->getShape() != rhs.getShape() || this->getPrecision() != rhs.getPrecision())
return false; return false;
@ -35,7 +41,7 @@ bool BlockedMemoryDesc::isCompatibleInternal(const BlockedMemoryDesc &rhs, CmpMa
return false; return false;
} }
if (cmpMask.test(BLOCKED_DESC_OFFSET_MASK_POS)) { if (cmpMask.test(OFFSET_MASK_POS)) {
return dimsEqualWeak(this->getOffsetPadding(), rhs.getOffsetPadding()); return dimsEqualWeak(this->getOffsetPadding(), rhs.getOffsetPadding());
} }

View File

@ -11,17 +11,17 @@
namespace ov { namespace ov {
namespace intel_cpu { namespace intel_cpu {
#define BLOCKED_DESC_FULL_MASK 0xffffffff
#define BLOCKED_DESC_EMPTY_MASK 0x0
#define BLOCKED_DESC_SKIP_OFFSET_MASK 0x7fffffff
#define BLOCKED_DESC_OFFSET_MASK_POS 31
class BlockedMemoryDesc : public virtual MemoryDesc { class BlockedMemoryDesc : public virtual MemoryDesc {
public: public:
using CmpMask = std::bitset<32>; using CmpMask = std::bitset<32>;
public: public:
BlockedMemoryDesc() {} BlockedMemoryDesc() = default;
static constexpr CmpMask FULL_MASK{0xffffffff};
static constexpr CmpMask EMPTY_MASK{0x0};
static constexpr CmpMask SKIP_OFFSET_MASK{0x7fffffff};
static constexpr size_t OFFSET_MASK_POS{31};
/** /**
* @brief Returns the blocked dimensions * @brief Returns the blocked dimensions
@ -76,7 +76,7 @@ public:
virtual bool isCompatible(const BlockedMemoryDesc &rhs, CmpMask cmpMask) const = 0; virtual bool isCompatible(const BlockedMemoryDesc &rhs, CmpMask cmpMask) const = 0;
using MemoryDesc::isCompatible; using MemoryDesc::isCompatible;
virtual ~BlockedMemoryDesc() = default; ~BlockedMemoryDesc() override = default;
std::string serializeFormat() const override; std::string serializeFormat() const override;
@ -88,7 +88,7 @@ protected:
* Doesn't perform descs specific attributes check * Doesn't perform descs specific attributes check
* @return true if compatible, otherwise false * @return true if compatible, otherwise false
*/ */
bool isCompatibleInternal(const BlockedMemoryDesc &rhs, CmpMask cmpMask = BLOCKED_DESC_FULL_MASK) const; bool isCompatibleInternal(const BlockedMemoryDesc &rhs, CmpMask cmpMask = FULL_MASK) const;
mutable VectorDims blockedDims; mutable VectorDims blockedDims;
mutable VectorDims strides; mutable VectorDims strides;

View File

@ -24,8 +24,8 @@ public:
bool isCompatible(const MemoryDesc& rhs) const override; bool isCompatible(const MemoryDesc& rhs) const override;
bool isCompatible(const BlockedMemoryDesc& rhs, CmpMask cmpMask) const override; bool isCompatible(const BlockedMemoryDesc& rhs, CmpMask cmpMask) const override;
bool isCompatible(const CpuBlockedMemoryDesc &rhs, CmpMask cmpMask = BLOCKED_DESC_FULL_MASK) const; bool isCompatible(const CpuBlockedMemoryDesc &rhs, CmpMask cmpMask = BlockedMemoryDesc::FULL_MASK) const;
bool isCompatible(const DnnlBlockedMemoryDesc &rhs, CmpMask cmpMask = BLOCKED_DESC_FULL_MASK) const; bool isCompatible(const DnnlBlockedMemoryDesc &rhs, CmpMask cmpMask = BlockedMemoryDesc::FULL_MASK) const;
InferenceEngine::Precision getPrecision() const override { InferenceEngine::Precision getPrecision() const override {
return precision; return precision;
@ -92,7 +92,7 @@ private:
MemoryDescPtr cloneWithNewDimsImp(const VectorDims& dims) const override; MemoryDescPtr cloneWithNewDimsImp(const VectorDims& dims) const override;
void setPrecision(InferenceEngine::Precision prc) override { void setPrecision(InferenceEngine::Precision prc) override {
precision = std::move(prc); precision = prc;
} }
private: private:

View File

@ -279,7 +279,7 @@ bool DnnlBlockedMemoryDesc::isCompatible(const DnnlBlockedMemoryDesc& rhs, CmpMa
return false; return false;
const uint64_t stride_mask = (0xffffffffffffffff << cmpMask.size()) | cmpMask.to_ullong(); const uint64_t stride_mask = (0xffffffffffffffff << cmpMask.size()) | cmpMask.to_ullong();
const bool checkOffset = cmpMask.test(BLOCKED_DESC_OFFSET_MASK_POS); const bool checkOffset = cmpMask.test(OFFSET_MASK_POS);
const auto thisExtra = wrappedThis.extra(); const auto thisExtra = wrappedThis.extra();
const auto rhsExtra = wrappedRhs.extra(); const auto rhsExtra = wrappedRhs.extra();

View File

@ -28,8 +28,8 @@ public:
bool isCompatible(const MemoryDesc& rhs) const override; bool isCompatible(const MemoryDesc& rhs) const override;
bool isCompatible(const BlockedMemoryDesc& rhs, CmpMask cmpMask) const override; bool isCompatible(const BlockedMemoryDesc& rhs, CmpMask cmpMask) const override;
bool isCompatible(const CpuBlockedMemoryDesc &rhs, CmpMask cmpMask = BLOCKED_DESC_FULL_MASK) const; bool isCompatible(const CpuBlockedMemoryDesc &rhs, CmpMask cmpMask = FULL_MASK) const;
bool isCompatible(const DnnlBlockedMemoryDesc &rhs, CmpMask cmpMask = BLOCKED_DESC_FULL_MASK) const; bool isCompatible(const DnnlBlockedMemoryDesc &rhs, CmpMask cmpMask = FULL_MASK) const;
const VectorDims& getBlockDims() const override { const VectorDims& getBlockDims() const override {
return blockedDims; return blockedDims;

View File

@ -87,13 +87,13 @@ Node::Node(const std::shared_ptr<ngraph::Node>& op,
temporary(false), temporary(false),
constant(ConstantType::Unknown), constant(ConstantType::Unknown),
context(ctx), context(ctx),
algorithm(Algorithm::Default),
fusingPort(-1),
engine(ctx->getEngine()), engine(ctx->getEngine()),
name(op->get_friendly_name()), name(op->get_friendly_name()),
typeStr(op->get_type_name()), typeStr(op->get_type_name()),
type(TypeFromName(op->get_type_name())), type(TypeFromName(op->get_type_name())),
profiling(op->get_friendly_name()) { profiling(op->get_friendly_name()) {
algorithm = Algorithm::Default;
fusingPort = -1;
const std::string errorPrefix = "Ngraph operation " + std::string(op->get_type_name()) + " with name " + op->get_friendly_name(); const std::string errorPrefix = "Ngraph operation " + std::string(op->get_type_name()) + " with name " + op->get_friendly_name();
for (size_t i = 0; i < op->get_input_size(); i++) { for (size_t i = 0; i < op->get_input_size(); i++) {
@ -139,18 +139,21 @@ Node::Node(const std::shared_ptr<ngraph::Node>& op,
addOriginalLayer(name); addOriginalLayer(name);
} }
auto primitivesPriority = getPrimitivesPriorityValue(op); auto primitivesPriority = getImplPriorityValue(op);
if (!primitivesPriority.empty()) { if (!primitivesPriority.empty()) {
std::istringstream stream(primitivesPriority); std::istringstream stream(primitivesPriority);
std::string str; std::string str;
while (getline(stream, str, ',')) { while (getline(stream, str, ',')) {
if (str.substr(0, 4) != "cpu:") if (str.substr(0, 4) != "cpu:")
continue; continue;
implPriorities.push_back(parse_impl_name(str)); customImplPriorities.push_back(parse_impl_name(str));
if (implPriorities[implPriorities.size() - 1] == impl_desc_type::unknown && if (customImplPriorities.back() == impl_desc_type::unknown &&
str != "cpu:unknown") str != "cpu:unknown")
IE_THROW() << "Unsupported CPU implementation " << str << " for node " << getName(); IE_THROW() << "Unsupported CPU implementation " << str << " for node " << getName();
} }
// add default primitive priorities as a fallback for the custom ones
const auto& defaultImplPriorities = getDefaultImplPriority();
customImplPriorities.insert(customImplPriorities.end(), defaultImplPriorities.begin(), defaultImplPriorities.end());
} }
std::string inputMemoryFormats = getInputMemoryFormats(op); std::string inputMemoryFormats = getInputMemoryFormats(op);
@ -262,7 +265,7 @@ void Node::createPrimitive() {
} }
void Node::selectOptimalPrimitiveDescriptor() { void Node::selectOptimalPrimitiveDescriptor() {
selectPreferPrimitiveDescriptor(getPrimitivesPriority(), false); selectPreferPrimitiveDescriptor(getImplPriority(), false);
} }
void Node::selectPreferPrimitiveDescriptor(const std::vector<impl_desc_type>& priority, bool ignoreConstInputs) { void Node::selectPreferPrimitiveDescriptor(const std::vector<impl_desc_type>& priority, bool ignoreConstInputs) {
@ -621,44 +624,51 @@ void Node::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty()) if (!supportedPrimitiveDescriptors.empty())
return; return;
auto attr = initPrimitiveAttr(); auto addSupportedPrimitiveDescriptor = [&](const dnnl::primitive_desc& prim_desc) {
std::vector<PortConfig> inConfs, outConfs;
const int inPlaceOutPort = canBeInPlace() ? 0 : -1;
for (auto& desc : descs) {
primitive_desc_iterator itpd = desc;
while (static_cast<bool>(itpd)) {
NodeConfig config;
for (size_t i = 0; i < descInputNumbers(); i++) { for (size_t i = 0; i < descInputNumbers(); i++) {
PortConfig portConfig; auto desc = getSrcMemDesc(prim_desc, i);
portConfig.inPlace(-1);
portConfig.constant(false); inConfs.emplace_back(desc, BlockedMemoryDesc::EMPTY_MASK);
auto desc = getSrcMemDesc(itpd, i);
if (desc->getType() & MemoryDescType::Blocked) {
portConfig.setMemDesc(std::dynamic_pointer_cast<BlockedMemoryDesc>(desc), BLOCKED_DESC_EMPTY_MASK);
} else {
portConfig.setMemDesc(std::move(desc));
}
config.inConfs.push_back(portConfig);
} }
for (size_t i = 0; i < descOutputNumbers(); i++) { for (size_t i = 0; i < descOutputNumbers(); i++) {
PortConfig portConfig; auto desc = getDstMemDesc(prim_desc, i);
portConfig.inPlace(canBeInPlace() ? 0 : -1);
portConfig.constant(false); outConfs.emplace_back(desc, BlockedMemoryDesc::EMPTY_MASK, inPlaceOutPort);
auto desc = getDstMemDesc(itpd, i);
if (desc->getType() & MemoryDescType::Blocked) {
portConfig.setMemDesc(std::dynamic_pointer_cast<BlockedMemoryDesc>(desc), BLOCKED_DESC_EMPTY_MASK);
} else {
portConfig.setMemDesc(std::move(desc));
} }
config.outConfs.push_back(portConfig);
} const NodeConfig config(inConfs, outConfs);
impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str()); const impl_desc_type impl_type = parse_impl_name(prim_desc.impl_info_str());
supportedPrimitiveDescriptors.emplace_back(config, impl_type); supportedPrimitiveDescriptors.emplace_back(config, impl_type);
if (!itpd.next_impl()) };
break;
} /* When custom implementation priorities are NOT defined it is enough to
* just use the first implementation from the priority list.
* When custom implementation priorities are defined, all the implementations should be considered,
* since custom implementations can be not available at all, so a fallback to the default ones must happen
* To achive the fallback, it is necessary to create a supported primitive descriptor for each implementation
* since oneDNN primitive is mutating while iterating */
for (auto& desc : descs) {
auto first_desc = dnnl::primitive_desc(DnnlExtensionUtils::clone_primitive_desc(desc.get()));
const bool first_match = customImplPriorities.empty();
DnnlExtensionUtils::for_each_implementation(desc,
first_match,
[&](impl_desc_type implType) {
return contains(getImplPriority(), implType);
},
[&](dnnl::primitive_desc& desc) {
addSupportedPrimitiveDescriptor(desc);
});
// fallback. if none of the primitive types is present in the priority list just add first implementation
// @todo this fallback is not necessary if primitive priority list is filled correctly
if (supportedPrimitiveDescriptors.empty())
addSupportedPrimitiveDescriptor(first_desc);
} }
} }
@ -971,8 +981,8 @@ void Node::cleanup() {
} }
} }
const std::vector<impl_desc_type>& Node::getPrimitivesPriority() { const std::vector<impl_desc_type>& Node::getDefaultImplPriority() {
std::vector<impl_desc_type> priorities = { static const std::vector<impl_desc_type> priorities {
impl_desc_type::unknown, impl_desc_type::unknown,
// Undef impl type is used to express use-cases there real type is unkown during compilation // Undef impl type is used to express use-cases there real type is unkown during compilation
// Undef has higher priority than defined types in order to force primitive selection logic to make decision based on other properties // Undef has higher priority than defined types in order to force primitive selection logic to make decision based on other properties
@ -1011,11 +1021,16 @@ const std::vector<impl_desc_type>& Node::getPrimitivesPriority() {
impl_desc_type::ref_any, impl_desc_type::ref_any,
impl_desc_type::ref, impl_desc_type::ref,
}; };
for (const auto& impl : priorities) {
if (std::find(implPriorities.begin(), implPriorities.end(), impl) == implPriorities.end()) return priorities;
implPriorities.push_back(impl); }
}
return implPriorities; const std::vector<impl_desc_type>& Node::getImplPriority() {
if (!customImplPriorities.empty())
return customImplPriorities;
return getDefaultImplPriority();
} }
PortDescBasePtr Node::getConsistentInputDesc(const NodeConfig &config, size_t idx) const { PortDescBasePtr Node::getConsistentInputDesc(const NodeConfig &config, size_t idx) const {
@ -1126,7 +1141,7 @@ void Node::initOptimalPrimitiveDescriptor() {
// it is assumed that the nodes will define dense tensors on output edges // it is assumed that the nodes will define dense tensors on output edges
// if it is not the case the implementation must redefine this behaviour // if it is not the case the implementation must redefine this behaviour
if (outMemDesc->getType() & Blocked) { if (outMemDesc->getType() & Blocked) {
config.outConfs[i].setMemDesc(std::dynamic_pointer_cast<BlockedMemoryDesc>(outMemDesc), BLOCKED_DESC_FULL_MASK); config.outConfs[i].setMemDesc(std::dynamic_pointer_cast<BlockedMemoryDesc>(outMemDesc), BlockedMemoryDesc::FULL_MASK);
} }
} }
} }
@ -1144,18 +1159,18 @@ bool Node::isConfigDefined(const NodeConfig &config) const {
return true; return true;
} }
MemoryDescPtr Node::getSrcMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) { MemoryDescPtr Node::getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const {
if (getInputShapeAtPort(idx).isDynamic()) { if (getInputShapeAtPort(idx).isDynamic()) {
return DnnlExtensionUtils::makeUndefinedDesc(primitive_desc_it.src_desc(idx), getInputShapeAtPort(idx)); return DnnlExtensionUtils::makeUndefinedDesc(prim_desc.src_desc(idx), getInputShapeAtPort(idx));
} }
return DnnlExtensionUtils::makeDescriptor(primitive_desc_it.src_desc(idx)); return DnnlExtensionUtils::makeDescriptor(prim_desc.src_desc(idx));
} }
MemoryDescPtr Node::getDstMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) { MemoryDescPtr Node::getDstMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const {
if (getOutputShapeAtPort(idx).isDynamic()) { if (getOutputShapeAtPort(idx).isDynamic()) {
return DnnlExtensionUtils::makeUndefinedDesc(primitive_desc_it.dst_desc(idx), getOutputShapeAtPort(idx)); return DnnlExtensionUtils::makeUndefinedDesc(prim_desc.dst_desc(idx), getOutputShapeAtPort(idx));
} }
return DnnlExtensionUtils::makeDescriptor(primitive_desc_it.dst_desc(idx)); return DnnlExtensionUtils::makeDescriptor(prim_desc.dst_desc(idx));
} }
void Node::appendPostOpArgs(const dnnl::primitive_attr& attr, void Node::appendPostOpArgs(const dnnl::primitive_attr& attr,
@ -1627,15 +1642,16 @@ void Node::addSupportedPrimDesc(const std::vector<PortConfigurator>& inPortConfi
if (!fill_port(outPortConfigs[i], dims, prc, config.outConfs)) if (!fill_port(outPortConfigs[i], dims, prc, config.outConfs))
return; return;
} }
supportedPrimitiveDescriptors.push_back({config, implType});
supportedPrimitiveDescriptors.emplace_back(config, implType);
} }
void Node::initializeDQScales(const float* scaleData, const size_t scaleSize) { void Node::initializeDQScales(const float* scaleData, const size_t scaleSize) {
bool scalePerTensor;
if (!DQScales.empty() || !scaleSize) if (!DQScales.empty() || !scaleSize)
IE_THROW() << "DQ scales is preset or scale size is 0, ##" << getName(); IE_THROW() << "DQ scales is preset or scale size is 0, ##" << getName();
DQScales.reserve(scaleSize); DQScales.reserve(scaleSize);
scalePerTensor = true;
bool scalePerTensor = true;
for (size_t i = 0; i < scaleSize; i++) { for (size_t i = 0; i < scaleSize; i++) {
DQScales.push_back(scaleData[i]); DQScales.push_back(scaleData[i]);
if (scaleData[i] != scaleData[0]) if (scaleData[i] != scaleData[0])

View File

@ -76,15 +76,11 @@ private:
class NodeDesc { class NodeDesc {
public: public:
NodeDesc(const NodeConfig& conf, impl_desc_type type): config(conf) { NodeDesc(NodeConfig conf, impl_desc_type type):
implementationType = type; config(std::move(conf)), implementationType(type), executorFactory(nullptr) {}
executorFactory = nullptr;
}
NodeDesc(const NodeConfig& conf, impl_desc_type type, ExecutorFactoryPtr factory): config(conf) { NodeDesc(NodeConfig conf, impl_desc_type type, ExecutorFactoryPtr factory):
implementationType = type; config(std::move(conf)), implementationType(type), executorFactory(factory) {}
executorFactory = factory;
}
const NodeConfig& getConfig() const { const NodeConfig& getConfig() const {
return config; return config;
@ -560,8 +556,8 @@ protected:
virtual PortDescBasePtr getConsistentInputDesc(const NodeConfig &config, size_t idx) const; virtual PortDescBasePtr getConsistentInputDesc(const NodeConfig &config, size_t idx) const;
virtual PortDescBasePtr getConsistentOutputDesc(const NodeConfig &config, size_t idx) const; virtual PortDescBasePtr getConsistentOutputDesc(const NodeConfig &config, size_t idx) const;
virtual MemoryDescPtr getSrcMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx); virtual MemoryDescPtr getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const;
virtual MemoryDescPtr getDstMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx); virtual MemoryDescPtr getDstMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const;
virtual AttrPtr initPrimitiveAttr() { return nullptr; } virtual AttrPtr initPrimitiveAttr() { return nullptr; }
@ -574,7 +570,7 @@ protected:
std::vector <NodePtr> fusedWith; std::vector <NodePtr> fusedWith;
std::vector <NodePtr> mergedWith; std::vector <NodePtr> mergedWith;
std::vector <impl_desc_type> implPriorities; std::vector <impl_desc_type> customImplPriorities;
std::vector <dnnl::memory::format_tag> inputMemoryFormatsFilter; std::vector <dnnl::memory::format_tag> inputMemoryFormatsFilter;
std::vector <dnnl::memory::format_tag> outputMemoryFormatsFilter; std::vector <dnnl::memory::format_tag> outputMemoryFormatsFilter;
bool enforceBF16evenForGraphTail = false; bool enforceBF16evenForGraphTail = false;
@ -619,7 +615,11 @@ protected:
bool isConfigDefined(const NodeConfig &config) const; bool isConfigDefined(const NodeConfig &config) const;
virtual bool canBeInPlace() const; virtual bool canBeInPlace() const;
virtual const std::vector<impl_desc_type>& getPrimitivesPriority(); /* returns default implementaion prioirity */
virtual const std::vector<impl_desc_type>& getDefaultImplPriority();
/* returns custom implementation priority + default implementation priority appended as a fallback
* if custom implementaiton priority is not specified, returns default implementation priority */
const std::vector<impl_desc_type>& getImplPriority();
virtual std::vector<dnnl::memory::format_tag> getAvailableFormatsForDims(const Shape& dims) const; virtual std::vector<dnnl::memory::format_tag> getAvailableFormatsForDims(const Shape& dims) const;
@ -724,9 +724,7 @@ private:
// copies of same content with different layouts. // copies of same content with different layouts.
std::unordered_map<std::string, MemoryPtr> privateWeightCache; std::unordered_map<std::string, MemoryPtr> privateWeightCache;
#ifdef CPU_DEBUG_CAPS CPU_DEBUG_CAP_ENABLE(friend class Verbose);
friend class Verbose;
#endif
}; };
template <class... T> template <class... T>

View File

@ -116,9 +116,6 @@ AdaptivePooling::AdaptivePooling(const std::shared_ptr<ngraph::Node>& op, const
} }
void AdaptivePooling::getSupportedDescriptors() { void AdaptivePooling::getSupportedDescriptors() {
if (!descs.empty())
return;
if (getParentEdges().size() != 2) if (getParentEdges().size() != 2)
IE_THROW() << errorPrefix << "has incorrect number of input edges: " << getParentEdges().size(); IE_THROW() << errorPrefix << "has incorrect number of input edges: " << getParentEdges().size();
if (getChildEdges().size() < (algorithm == Algorithm::AdaptivePoolingMax ? 2 : 1)) if (getChildEdges().size() < (algorithm == Algorithm::AdaptivePoolingMax ? 2 : 1))

View File

@ -929,9 +929,6 @@ BinaryConvolution::BinaryConvolution(const std::shared_ptr<ngraph::Node>& op, co
} }
void BinaryConvolution::getSupportedDescriptors() { void BinaryConvolution::getSupportedDescriptors() {
if (!descs.empty())
return;
withBinarization = isFusedWith(Type::FakeQuantize); withBinarization = isFusedWith(Type::FakeQuantize);
withSum = false; withSum = false;
size_t expectedInputEdgesNum = 2; size_t expectedInputEdgesNum = 2;

View File

@ -163,7 +163,7 @@ void Concat::initSupportedPrimitiveDescriptors() {
if (isDynamicNode()) { if (isDynamicNode()) {
config.inConfs[i].setMemDesc(desc); config.inConfs[i].setMemDesc(desc);
} else { } else {
config.inConfs[i].setMemDesc(desc, BLOCKED_DESC_EMPTY_MASK); config.inConfs[i].setMemDesc(desc, BlockedMemoryDesc::EMPTY_MASK);
} }
} }
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::ref); supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::ref);
@ -197,7 +197,7 @@ void Concat::initSupportedPrimitiveDescriptors() {
SizeVector strides(numOfDim); SizeVector strides(numOfDim);
strides.back() = 1lu; strides.back() = 1lu;
size_t offset = Shape::UNDEFINED_DIM; size_t offset = Shape::UNDEFINED_DIM;
BlockedMemoryDesc::CmpMask mask = BLOCKED_DESC_SKIP_OFFSET_MASK; // any offset BlockedMemoryDesc::CmpMask mask = BlockedMemoryDesc::SKIP_OFFSET_MASK; // any offset
for (size_t i = 2; i <= numOfDim; i++) { for (size_t i = 2; i <= numOfDim; i++) {
if (numOfDim - i < axis) { if (numOfDim - i < axis) {
@ -509,7 +509,7 @@ void Concat::initOptimalPrimitiveDescriptor() {
firstOutBlockingDesc->getOffsetPadding() + offset, firstOutBlockingDesc->getOffsetPadding() + offset,
firstOutBlockingDesc->getOffsetPaddingToData(), firstOutBlockingDesc->getOffsetPaddingToData(),
firstOutBlockingDesc->getStrides()), firstOutBlockingDesc->getStrides()),
BLOCKED_DESC_FULL_MASK); BlockedMemoryDesc::FULL_MASK);
size_t axisSize = 1; size_t axisSize = 1;
auto firstInpBlockingDesc = config.inConfs[0].getMemDesc()->as<BlockedMemoryDesc>(); auto firstInpBlockingDesc = config.inConfs[0].getMemDesc()->as<BlockedMemoryDesc>();

View File

@ -324,8 +324,8 @@ InferenceEngine::Precision Convolution::fusedEltwisePrecision(const NodePtr& fus
return eltwisePrecision; return eltwisePrecision;
} }
const std::vector<impl_desc_type>& Convolution::getPrimitivesPriority() { const std::vector<impl_desc_type>& Convolution::getDefaultImplPriority() {
std::vector<impl_desc_type> priorities = { static const std::vector<impl_desc_type> priorities = {
impl_desc_type::unknown, impl_desc_type::unknown,
impl_desc_type::dw_acl, impl_desc_type::dw_acl,
impl_desc_type::winograd_acl, impl_desc_type::winograd_acl,
@ -363,11 +363,7 @@ const std::vector<impl_desc_type>& Convolution::getPrimitivesPriority() {
impl_desc_type::ref, impl_desc_type::ref,
}; };
for (const auto& impl : priorities) { return priorities;
if (std::find(implPriorities.begin(), implPriorities.end(), impl) == implPriorities.end())
implPriorities.push_back(impl);
}
return implPriorities;
} }
const bool Convolution::isBrgConvAvailable = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core); const bool Convolution::isBrgConvAvailable = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core);
@ -381,10 +377,10 @@ void Convolution::getSupportedDescriptors() {
attrs.reserve(2); attrs.reserve(2);
withBiases = getOriginalInputsNumber() == 3; withBiases = getOriginalInputsNumber() == 3;
if (!implPriorities.empty()) { if (!customImplPriorities.empty()) {
isPrimitivesPriorityDefined = true; isPrimitivesPriorityDefined = true;
// winograd support only constant weights and bias // winograd support only constant weights and bias
isWino = std::find(implPriorities.begin(), implPriorities.end(), impl_desc_type::jit_avx512_winograd) != implPriorities.end() && isWino = std::find(customImplPriorities.begin(), customImplPriorities.end(), impl_desc_type::jit_avx512_winograd) != customImplPriorities.end() &&
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) && !canBeExecutedInInt8() && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) && !canBeExecutedInInt8() &&
getParentEdgeAt(1)->getParent()->isConstant() && getParentEdgeAt(1)->getParent()->getType() == Type::Input && getParentEdgeAt(1)->getParent()->isConstant() && getParentEdgeAt(1)->getParent()->getType() == Type::Input &&
(withBiases ? (getParentEdgeAt(2)->getParent()->isConstant() && getParentEdgeAt(2)->getParent()->getType() == Type::Input) : true); (withBiases ? (getParentEdgeAt(2)->getParent()->isConstant() && getParentEdgeAt(2)->getParent()->getType() == Type::Input) : true);
@ -709,88 +705,79 @@ void Convolution::setPostOps(dnnl::primitive_attr& attr,
} }
void Convolution::selectOptimalPrimitiveDescriptor() { void Convolution::selectOptimalPrimitiveDescriptor() {
selectPreferPrimitiveDescriptor(getPrimitivesPriority(), true); selectPreferPrimitiveDescriptor(getImplPriority(), true);
} }
void Convolution::initSupportedPrimitiveDescriptors() { void Convolution::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty()) if (!supportedPrimitiveDescriptors.empty())
return; return;
bool containJitImpl = false; auto getBlockedMask = [](const std::shared_ptr<MemoryDesc>& memDesc, const bool isGrouped) {
if (memDesc->getType() & MemoryDescType::Blocked && !isGrouped)
return BlockedMemoryDesc::EMPTY_MASK;
return BlockedMemoryDesc::FULL_MASK;
};
for (size_t dIdx = 0; dIdx < descs.size(); dIdx++) { auto addSupportedPrimitiveDescriptor = [&](const dnnl::primitive_desc& prim_desc) {
const auto& desc = descs[dIdx]; std::vector<PortConfig> inConfs, outConfs;
const int inPlaceOutPort = withSum ? static_cast<int>(getParentEdges().size()) - 1 : -1;
if (containJitImpl && isPossibleToSkipInitConfig(desc))
continue;
auto itpd = desc;
while (itpd) {
NodeConfig config;
for (size_t i = 0; i < descInputNumbers(); i++) { for (size_t i = 0; i < descInputNumbers(); i++) {
PortConfig dataConfig; auto desc = getSrcMemDesc(prim_desc, i);
dataConfig.inPlace(-1);
dataConfig.constant(false);
auto desc = getSrcMemDesc(itpd, i);
if (desc->getType() & MemoryDescType::Blocked && !isGrouped) {
dataConfig.setMemDesc(std::dynamic_pointer_cast<BlockedMemoryDesc>(desc), BLOCKED_DESC_EMPTY_MASK);
} else {
dataConfig.setMemDesc(std::move(desc));
}
config.inConfs.push_back(dataConfig); inConfs.emplace_back(desc, getBlockedMask(desc, isGrouped));
} }
if (withDWConv) { if (withDWConv) {
auto weightsPrc = DnnlExtensionUtils::IEPrecisionToDataType(dw_conv_in_dt == dnnl_u8 ? Precision::I8 : Precision::FP32); const std::vector<size_t> dwWeightsDims{dw_conv_oc, 1, 1, dw_conv_kernel[Y_AXIS], dw_conv_kernel[X_AXIS]};
auto biasPrc = memory::data_type::f32; const std::vector<size_t> dwBiasesDims{dw_conv_oc};
std::vector<size_t> dwWeightsDims({dw_conv_oc, 1, 1, dw_conv_kernel[Y_AXIS], dw_conv_kernel[X_AXIS]}); const auto dwWeightsPrc = DnnlExtensionUtils::IEPrecisionToDataType(dw_conv_in_dt == dnnl_u8 ? Precision::I8 : Precision::FP32);
std::vector<size_t> dwBiasesDims({dw_conv_oc}); const auto dwWeightsDesc = std::make_shared<DnnlBlockedMemoryDesc>(Shape(dwWeightsDims), dwWeightsPrc, memory::format_tag::Goihw8g);
inConfs.emplace_back(dwWeightsDesc);
PortConfig dataConfig; const auto dwBiasPrc = memory::data_type::f32;
dataConfig.inPlace(-1); const auto dwBiasDesc = std::make_shared<DnnlBlockedMemoryDesc>(Shape(dwBiasesDims), dwBiasPrc, memory::format_tag::x);
dataConfig.constant(false); inConfs.emplace_back(dwBiasDesc);
dataConfig.setMemDesc(std::make_shared<DnnlBlockedMemoryDesc>(Shape(dwWeightsDims), weightsPrc, memory::format_tag::Goihw8g));
config.inConfs.push_back(dataConfig);
dataConfig.setMemDesc(std::make_shared<DnnlBlockedMemoryDesc>(Shape(dwBiasesDims), biasPrc, memory::format_tag::x));
config.inConfs.push_back(dataConfig);
} }
for (size_t i = 0; i < descOutputNumbers(); i++) { for (size_t i = 0; i < descOutputNumbers(); i++) {
PortConfig dataConfig; auto desc = getDstMemDesc(prim_desc, i);
if (withSum) {
dataConfig.inPlace(getParentEdges().size() - 1);
}
dataConfig.constant(false); outConfs.emplace_back(desc, getBlockedMask(desc, isGrouped), inPlaceOutPort);
auto desc = getDstMemDesc(itpd, i);
if (desc->getType() & MemoryDescType::Blocked && !isGrouped) {
dataConfig.setMemDesc(std::dynamic_pointer_cast<BlockedMemoryDesc>(desc), BLOCKED_DESC_EMPTY_MASK);
} else {
dataConfig.setMemDesc(std::move(desc));
} }
config.outConfs.push_back(dataConfig);
if (withSum) { if (withSum) {
dataConfig.inPlace(-1); const auto outputPrecision = outConfs.back().getMemDesc()->getPrecision();
dataConfig.setMemDesc(getSumMemDesc(itpd)->cloneWithNewPrecision(dataConfig.getMemDesc()->getPrecision())); const auto sumDesc = getSumMemDesc(prim_desc)->cloneWithNewPrecision(outputPrecision);
config.inConfs.push_back(dataConfig); inConfs.emplace_back(sumDesc);
} }
}
impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str()); NodeConfig config(inConfs, outConfs);
if (impl_type & jit) const impl_desc_type impl_type = parse_impl_name(prim_desc.impl_info_str());
containJitImpl = true;
supportedPrimitiveDescriptors.emplace_back(config, impl_type); supportedPrimitiveDescriptors.emplace_back(config, impl_type);
descIdx.push_back(dIdx); };
if (!itpd.next_impl()) for (size_t dIdx = 0; dIdx < descs.size(); dIdx++) {
break; auto& desc = descs[dIdx];
} auto first_desc = dnnl::primitive_desc(DnnlExtensionUtils::clone_primitive_desc(desc.get()));
const bool first_match = customImplPriorities.empty();
DnnlExtensionUtils::for_each_implementation(desc,
first_match,
[&](impl_desc_type implType) {
return contains(getImplPriority(), implType);
},
[&](dnnl::primitive_desc& desc) {
addSupportedPrimitiveDescriptor(desc);
descIdx.push_back(dIdx);
});
// fallback. if none of the primitive types is present in the priority list just add first implementation
// @todo this fallback is not necessary if primitive priority list is filled correctly
if (supportedPrimitiveDescriptors.empty())
addSupportedPrimitiveDescriptor(first_desc);
} }
} }
@ -894,7 +881,7 @@ void Convolution::createDescriptor(const std::vector<MemoryDescPtr>& inputDesc,
const auto desc = createDescriptorInternal(getEngine(), const auto desc = createDescriptorInternal(getEngine(),
inDnnlDesc, weightDnnlDesc, biasDnnlDesc, outDnnlDesc, withBiases, inDnnlDesc, weightDnnlDesc, biasDnnlDesc, outDnnlDesc, withBiases,
stride, dilation, paddingL, paddingR, alg, attr); stride, dilation, paddingL, paddingR, alg, attr);
if (desc.get(true)) if (desc)
descs.emplace_back(desc); descs.emplace_back(desc);
} }
} }
@ -1101,47 +1088,13 @@ void Convolution::filterSupportedDescriptors() {
descs.erase(std::remove_if(descs.begin(), descs.end(), isNotSuitableDesc), descs.end()); descs.erase(std::remove_if(descs.begin(), descs.end(), isNotSuitableDesc), descs.end());
} }
bool Convolution::isPossibleToSkipInitConfig(const dnnl::primitive_desc &desc) const { std::shared_ptr<MemoryDesc> Convolution::getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const {
// WA: In some cases, we can predict in advance the type of primitive that will be called in the future.
// In particular, isPossibleToSkipInitConfig() checks whether we can skip the creation of primitives with
// gemm implementation, which significantly increase the network load time.
if (!inputMemoryFormatsFilter.empty() || !outputMemoryFormatsFilter.empty())
return false;
if (isPrimitivesPriorityDefined)
return false;
// Here we check that we will not delete jit_planar_conv primitive by mistake.
// It requires:
// 1) strides equal 1;
// 2) not grouped;
// 3) first dim of weights is not 1.
bool isPossibleJitPlanar = true;
if (isGrouped || weightDims[0] != 1)
isPossibleJitPlanar = false;
if (std::any_of(stride.begin(), stride.end(), [](const size_t s) { return s != 1; }))
isPossibleJitPlanar = false;
auto srcMemDesc = DnnlExtensionUtils::makeDescriptor(desc.src_desc());
auto dstMemDesc = DnnlExtensionUtils::makeDescriptor(desc.dst_desc());
auto srcDataType = srcMemDesc->getDataType();
auto dstDataType = dstMemDesc->getDataType();
bool isPlanarFloatConv = srcMemDesc->hasLayoutType(LayoutType::ncsp)
&& dstMemDesc->hasLayoutType(LayoutType::ncsp)
&& srcDataType == memory::data_type::f32
&& dstDataType == memory::data_type::f32;
return !isPossibleJitPlanar && isPlanarFloatConv;
}
std::shared_ptr<MemoryDesc> Convolution::getSrcMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) {
if (idx == 1) { if (idx == 1) {
// report original plain layout for weight since it needs to be reordered dynamically at runtime // report original plain layout for weight since it needs to be reordered dynamically at runtime
return std::make_shared<CpuBlockedMemoryDesc>(getOriginalInputPrecisionAtPort(idx), return std::make_shared<CpuBlockedMemoryDesc>(getOriginalInputPrecisionAtPort(idx),
Shape(getInputShapeAtPort(idx).getStaticDims())); Shape(getInputShapeAtPort(idx).getStaticDims()));
} }
auto desc = idx > 0 ? primitive_desc_it.weights_desc(idx - 1) : primitive_desc_it.src_desc(idx); auto desc = idx > 0 ? prim_desc.weights_desc(idx - 1) : prim_desc.src_desc(idx);
if (getInputShapeAtPort(idx).isDynamic()) { if (getInputShapeAtPort(idx).isDynamic()) {
return DnnlExtensionUtils::makeUndefinedDesc(desc, getInputShapeAtPort(idx)); return DnnlExtensionUtils::makeUndefinedDesc(desc, getInputShapeAtPort(idx));
} }
@ -1399,6 +1352,7 @@ void Convolution::prepareParams() {
key.attr); key.attr);
const bool found = DnnlExtensionUtils::find_implementation(prim_desc, key.implType); const bool found = DnnlExtensionUtils::find_implementation(prim_desc, key.implType);
if (found) { if (found) {
return std::make_shared<ConvolutionExecutor>( return std::make_shared<ConvolutionExecutor>(
prim_desc, prim_desc,
@ -1572,7 +1526,7 @@ void Convolution::redefineOutputMemory(const std::vector<VectorDims> &newOutputS
Node::redefineOutputMemory(newOutputShapes); Node::redefineOutputMemory(newOutputShapes);
} }
MemoryDescPtr Convolution::getSumMemDesc(primitive_desc_iterator &primitive_desc_it) { MemoryDescPtr Convolution::getSumMemDesc(const primitive_desc &primitive_desc_it) {
if (getOutputShapeAtPort(0).isDynamic()) { if (getOutputShapeAtPort(0).isDynamic()) {
return DnnlExtensionUtils::makeUndefinedDesc(primitive_desc_it.dst_desc(0), getOutputShapeAtPort(0)); return DnnlExtensionUtils::makeUndefinedDesc(primitive_desc_it.dst_desc(0), getOutputShapeAtPort(0));
} }

View File

@ -34,7 +34,7 @@ public:
return false; return false;
} }
InferenceEngine::Precision getRuntimePrecision() const override; InferenceEngine::Precision getRuntimePrecision() const override;
std::shared_ptr<MemoryDesc> getSrcMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) override; std::shared_ptr<MemoryDesc> getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const override;
dnnl::memory getWeights() const; dnnl::memory getWeights() const;
dnnl::memory getBias() const; dnnl::memory getBias() const;
@ -73,7 +73,7 @@ protected:
InferenceEngine::Precision fusedEltwisePrecision(const NodePtr& fusingNode) const; InferenceEngine::Precision fusedEltwisePrecision(const NodePtr& fusingNode) const;
void redefineOutputMemory(const std::vector<VectorDims> &newOutputShapes) override; void redefineOutputMemory(const std::vector<VectorDims> &newOutputShapes) override;
void addFusedNode(const NodePtr &fusingNode) override; void addFusedNode(const NodePtr &fusingNode) override;
const std::vector<impl_desc_type>& getPrimitivesPriority() override; const std::vector<impl_desc_type>& getDefaultImplPriority() override;
private: private:
enum class zpType { enum class zpType {
@ -105,12 +105,11 @@ private:
void setPostOps(dnnl::primitive_attr &attr, const VectorDims &dims, bool useLegacyPostOps, bool initWeights = false); void setPostOps(dnnl::primitive_attr &attr, const VectorDims &dims, bool useLegacyPostOps, bool initWeights = false);
void SetPostOpsAndZeroPoints(std::vector<dnnl::primitive_attr> &attrs); void SetPostOpsAndZeroPoints(std::vector<dnnl::primitive_attr> &attrs);
void filterSupportedDescriptors(); void filterSupportedDescriptors();
bool isPossibleToSkipInitConfig(const dnnl::primitive_desc &desc) const;
bool isNspcAvailable() const; bool isNspcAvailable() const;
InferenceEngine::Blob::Ptr createInternalBlob(InferenceEngine::SizeVector dims, size_t edgeNum, bool isGrouped = false); InferenceEngine::Blob::Ptr createInternalBlob(InferenceEngine::SizeVector dims, size_t edgeNum, bool isGrouped = false);
void updatePadding(); void updatePadding();
MemoryDescPtr getSumMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it); MemoryDescPtr getSumMemDesc(const dnnl::primitive_desc &primitive_desc_it);
MemoryPtr getOutputMemory() const; MemoryPtr getOutputMemory() const;
VectorDims makeInputDummyShape(const Shape& inpShape) const; VectorDims makeInputDummyShape(const Shape& inpShape) const;
VectorDims outputStaticShape() const; VectorDims outputStaticShape() const;

View File

@ -1077,7 +1077,7 @@ void Deconvolution::createDescriptor(const std::vector<MemoryDescPtr> &inputDesc
} }
} }
std::shared_ptr<MemoryDesc> Deconvolution::getSrcMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) { std::shared_ptr<MemoryDesc> Deconvolution::getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const {
if (idx == 2 && !withBiases) { if (idx == 2 && !withBiases) {
return std::make_shared<CpuBlockedMemoryDesc>(InferenceEngine::Precision::I32, Shape(getInputShapeAtPort(2).getStaticDims())); return std::make_shared<CpuBlockedMemoryDesc>(InferenceEngine::Precision::I32, Shape(getInputShapeAtPort(2).getStaticDims()));
} else if (idx > 0 && isInt8) { } else if (idx > 0 && isInt8) {
@ -1086,15 +1086,15 @@ std::shared_ptr<MemoryDesc> Deconvolution::getSrcMemDesc(dnnl::primitive_desc_it
return std::make_shared<CpuBlockedMemoryDesc>(getOriginalInputPrecisionAtPort(idx), Shape(getInputShapeAtPort(idx).getStaticDims())); return std::make_shared<CpuBlockedMemoryDesc>(getOriginalInputPrecisionAtPort(idx), Shape(getInputShapeAtPort(idx).getStaticDims()));
} }
auto desc = idx > 0 ? primitive_desc_it.weights_desc(idx - 1) : isInt8 ? primitive_desc_it.src_desc(idx) : primitive_desc_it.diff_dst_desc(idx); auto desc = idx > 0 ? prim_desc.weights_desc(idx - 1) : isInt8 ? prim_desc.src_desc(idx) : prim_desc.diff_dst_desc(idx);
if (getInputShapeAtPort(idx).isDynamic()) { if (getInputShapeAtPort(idx).isDynamic()) {
return DnnlExtensionUtils::makeUndefinedDesc(desc, getInputShapeAtPort(idx)); return DnnlExtensionUtils::makeUndefinedDesc(desc, getInputShapeAtPort(idx));
} }
return DnnlExtensionUtils::makeDescriptor(desc); return DnnlExtensionUtils::makeDescriptor(desc);
} }
std::shared_ptr<MemoryDesc> Deconvolution::getDstMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) { std::shared_ptr<MemoryDesc> Deconvolution::getDstMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const {
auto desc = isInt8 ? primitive_desc_it.dst_desc(idx) : primitive_desc_it.diff_src_desc(idx); auto desc = isInt8 ? prim_desc.dst_desc(idx) : prim_desc.diff_src_desc(idx);
if (getOutputShapeAtPort(idx).isDynamic()) { if (getOutputShapeAtPort(idx).isDynamic()) {
return DnnlExtensionUtils::makeUndefinedDesc(desc, getOutputShapeAtPort(idx)); return DnnlExtensionUtils::makeUndefinedDesc(desc, getOutputShapeAtPort(idx));
} }

View File

@ -34,8 +34,8 @@ public:
return static_cast<size_t>(getParentEdges().size()); return static_cast<size_t>(getParentEdges().size());
} }
std::shared_ptr<MemoryDesc> getSrcMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) override; std::shared_ptr<MemoryDesc> getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const override;
std::shared_ptr<MemoryDesc> getDstMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) override; std::shared_ptr<MemoryDesc> getDstMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const override;
InferenceEngine::Precision getRuntimePrecision() const override; InferenceEngine::Precision getRuntimePrecision() const override;

View File

@ -2049,7 +2049,7 @@ void Eltwise::initSupportedPrimitiveDescriptors() {
NodeConfig config; NodeConfig config;
for (size_t i = 0; i < getParentEdges().size(); i++) { for (size_t i = 0; i < getParentEdges().size(); i++) {
BlockedMemoryDesc::CmpMask inputMask = BLOCKED_DESC_SKIP_OFFSET_MASK; BlockedMemoryDesc::CmpMask inputMask = BlockedMemoryDesc::SKIP_OFFSET_MASK;
PortConfig portConfig; PortConfig portConfig;
// TODO [DS]: inplace // TODO [DS]: inplace
if (!isDynamicNode()) if (!isDynamicNode())
@ -2070,7 +2070,7 @@ void Eltwise::initSupportedPrimitiveDescriptors() {
portConfig.constant(false); portConfig.constant(false);
const auto &dstShape = getOutputShapeAtPort(0); const auto &dstShape = getOutputShapeAtPort(0);
BlockedMemoryDesc::CmpMask outputMask = BLOCKED_DESC_SKIP_OFFSET_MASK; BlockedMemoryDesc::CmpMask outputMask = BlockedMemoryDesc::SKIP_OFFSET_MASK;
if (!isDynamicNode() && dstShape.getDims()[0] == 1) { if (!isDynamicNode() && dstShape.getDims()[0] == 1) {
outputMask.reset(0); // accepts any stride on the batch axis outputMask.reset(0); // accepts any stride on the batch axis
} }
@ -2091,7 +2091,7 @@ void Eltwise::initSupportedPrimitiveDescriptors() {
} }
auto factory = std::make_shared<EltwiseExecutorFactory>(eltwiseAttrs, srcMemoryDescs, dstMemoryDescs, auto factory = std::make_shared<EltwiseExecutorFactory>(eltwiseAttrs, srcMemoryDescs, dstMemoryDescs,
std::make_shared<ExecutorContext>(context, getPrimitivesPriority())); std::make_shared<ExecutorContext>(context, getImplPriority()));
return {config, impl_type, !factory->isEmpty() ? factory : nullptr}; return {config, impl_type, !factory->isEmpty() ? factory : nullptr};
} else { } else {
@ -2332,7 +2332,7 @@ bool Eltwise::needPrepareParams() const {
} }
void Eltwise::selectOptimalPrimitiveDescriptor() { void Eltwise::selectOptimalPrimitiveDescriptor() {
selectPreferPrimitiveDescriptor(getPrimitivesPriority(), true); selectPreferPrimitiveDescriptor(getImplPriority(), true);
} }
void Eltwise::execute(dnnl::stream strm) { void Eltwise::execute(dnnl::stream strm) {

View File

@ -48,12 +48,11 @@ public:
typedef std::shared_ptr<ExecutorContext> Ptr; typedef std::shared_ptr<ExecutorContext> Ptr;
typedef std::shared_ptr<const ExecutorContext> CPtr; typedef std::shared_ptr<const ExecutorContext> CPtr;
ExecutorContext(const GraphContext::CPtr graphContext, const std::vector<impl_desc_type>& implPriorities) { ExecutorContext(const GraphContext::CPtr graphContext, const std::vector<impl_desc_type>& implPriorities)
this->runtimeCache = graphContext->getParamsCache(); : runtimeCache(graphContext->getParamsCache()),
this->scratchPad = graphContext->getScratchPad(); scratchPad(graphContext->getScratchPad()),
this->engine = graphContext->getEngine(); engine(graphContext->getEngine()),
this->implPriorities = implPriorities; implPriorities(implPriorities) {}
}
MultiCacheWeakPtr getRuntimeCache() const { MultiCacheWeakPtr getRuntimeCache() const {
return runtimeCache; return runtimeCache;
@ -75,9 +74,9 @@ private:
// weak_ptr is required to avoid cycle dependencies with MultiCache // weak_ptr is required to avoid cycle dependencies with MultiCache
// since ExecutorContext is stored in Executor itself // since ExecutorContext is stored in Executor itself
MultiCacheWeakPtr runtimeCache; MultiCacheWeakPtr runtimeCache;
DnnlScratchPadPtr scratchPad = nullptr; DnnlScratchPadPtr scratchPad;
dnnl::engine engine; dnnl::engine engine;
std::vector<impl_desc_type> implPriorities = {}; std::vector<impl_desc_type> implPriorities;
}; };
class ExecutorFactory { class ExecutorFactory {

View File

@ -62,8 +62,6 @@ Eye::Eye(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr context)
} }
void Eye::getSupportedDescriptors() { void Eye::getSupportedDescriptors() {
if (!descs.empty())
return;
if (!one_of(getParentEdges().size(), 3u, 4u)) if (!one_of(getParentEdges().size(), 3u, 4u))
THROW_ERROR << errorPrefix << "has incorrect number of input edges: " << getParentEdges().size(); THROW_ERROR << errorPrefix << "has incorrect number of input edges: " << getParentEdges().size();
if (getChildEdges().empty()) if (getChildEdges().empty())

View File

@ -8,6 +8,7 @@
#include "input.h" #include "input.h"
#include "fake_quantize.h" #include "fake_quantize.h"
#include "input.h" #include "input.h"
#include "memory_desc/blocked_memory_desc.h"
#include "reorder.h" #include "reorder.h"
#include "transformations/cpu_opset/common/op/fully_connected.hpp" #include "transformations/cpu_opset/common/op/fully_connected.hpp"
#include "ngraph/opsets/opset1.hpp" #include "ngraph/opsets/opset1.hpp"
@ -505,8 +506,8 @@ bool FullyConnected::created() const {
return getType() == Type::FullyConnected; return getType() == Type::FullyConnected;
} }
const std::vector<impl_desc_type>& FullyConnected::getPrimitivesPriority() { const std::vector<impl_desc_type>& FullyConnected::getDefaultImplPriority() {
std::vector<impl_desc_type> priorities = { static const std::vector<impl_desc_type> priorities = {
impl_desc_type::unknown, impl_desc_type::unknown,
impl_desc_type::acl, impl_desc_type::acl,
impl_desc_type::brgemm_sparse_avx512_amx, impl_desc_type::brgemm_sparse_avx512_amx,
@ -538,11 +539,7 @@ const std::vector<impl_desc_type>& FullyConnected::getPrimitivesPriority() {
impl_desc_type::ref, impl_desc_type::ref,
}; };
for (const auto& impl : priorities) { return priorities;
if (std::find(implPriorities.begin(), implPriorities.end(), impl) == implPriorities.end())
implPriorities.push_back(impl);
}
return implPriorities;
} }
// WA: creation DnnlMemoryDesc with format == any is prohibited // WA: creation DnnlMemoryDesc with format == any is prohibited
@ -639,53 +636,60 @@ void FullyConnected::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty()) if (!supportedPrimitiveDescriptors.empty())
return; return;
for (auto& desc : descs) {
primitive_desc_iterator itpd = desc;
while (static_cast<bool>(itpd)) {
// 3D FC requires implicit reshape so strides should be defined // 3D FC requires implicit reshape so strides should be defined
auto supportsUndefStridesAndOffset = [&]() { auto supportsUndefStridesAndOffset = [&]() {
return getOutputShapeAtPort(0).getRank() == 2; return getOutputShapeAtPort(0).getRank() == 2;
}; };
NodeConfig config; auto addSupportedPrimitiveDescriptor = [&](const dnnl::primitive_desc& prim_desc) {
std::vector<PortConfig> inConfs, outConfs;
const int inPlaceOutPort = canBeInPlace() ? 0 : -1;
for (size_t i = 0; i < descInputNumbers(); i++) { for (size_t i = 0; i < descInputNumbers(); i++) {
PortConfig portConfig; auto desc = getSrcMemDesc(prim_desc, i);
portConfig.inPlace(-1); const auto inputBlockedMask = (supportsUndefStridesAndOffset() && !(i == WEIGHTS_ID && useSparseWeights)) ?
portConfig.constant(false); BlockedMemoryDesc::EMPTY_MASK :
auto desc = getSrcMemDesc(itpd, i); BlockedMemoryDesc::FULL_MASK;
if (supportsUndefStridesAndOffset() && !(i == WEIGHTS_ID && useSparseWeights)) {
portConfig.setMemDesc(std::dynamic_pointer_cast<BlockedMemoryDesc>(desc), BLOCKED_DESC_EMPTY_MASK); inConfs.emplace_back(desc, inputBlockedMask);
} else {
portConfig.setMemDesc(desc);
}
config.inConfs.push_back(portConfig);
} }
const auto outputBlockedMask = supportsUndefStridesAndOffset() ? BlockedMemoryDesc::EMPTY_MASK : BlockedMemoryDesc::FULL_MASK;
for (size_t i = 0; i < descOutputNumbers(); i++) { for (size_t i = 0; i < descOutputNumbers(); i++) {
PortConfig portConfig; auto desc = getDstMemDesc(prim_desc, i);
portConfig.inPlace(canBeInPlace() ? 0 : -1);
portConfig.constant(false); outConfs.emplace_back(desc, outputBlockedMask, inPlaceOutPort);
auto desc = getDstMemDesc(itpd, i);
if (supportsUndefStridesAndOffset()) {
portConfig.setMemDesc(std::dynamic_pointer_cast<BlockedMemoryDesc>(desc), BLOCKED_DESC_EMPTY_MASK);
} else {
portConfig.setMemDesc(desc);
}
config.outConfs.push_back(portConfig);
} }
impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str()); const NodeConfig config(inConfs, outConfs);
const impl_desc_type impl_type = parse_impl_name(prim_desc.impl_info_str());
supportedPrimitiveDescriptors.emplace_back(config, impl_type); supportedPrimitiveDescriptors.emplace_back(config, impl_type);
};
if (!itpd.next_impl()) for (auto& desc : descs) {
break; auto first_desc = dnnl::primitive_desc(DnnlExtensionUtils::clone_primitive_desc(desc.get()));
}
const bool first_match = customImplPriorities.empty();
DnnlExtensionUtils::for_each_implementation(desc,
first_match,
[&](impl_desc_type implType) {
return contains(getImplPriority(), implType);
},
[&](dnnl::primitive_desc& desc) {
addSupportedPrimitiveDescriptor(desc);
});
// fallback. if none of the primitive types is present in the priority list just add first implementation
// @todo this fallback is not necessary if primitive priority list is filled correctly
if (supportedPrimitiveDescriptors.empty())
addSupportedPrimitiveDescriptor(first_desc);
} }
} }
std::shared_ptr<MemoryDesc> FullyConnected::getSrcMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) { std::shared_ptr<MemoryDesc> FullyConnected::getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const {
auto desc = idx > 0 ? primitive_desc_it.weights_desc(idx - 1) : primitive_desc_it.src_desc(idx); auto desc = idx > 0 ? prim_desc.weights_desc(idx - 1) : prim_desc.src_desc(idx);
if (getInputShapeAtPort(idx).getRank() == 3) { if (getInputShapeAtPort(idx).getRank() == 3) {
return std::make_shared<CpuBlockedMemoryDesc>( return std::make_shared<CpuBlockedMemoryDesc>(
@ -699,8 +703,8 @@ std::shared_ptr<MemoryDesc> FullyConnected::getSrcMemDesc(dnnl::primitive_desc_i
return DnnlExtensionUtils::makeDescriptor(desc); return DnnlExtensionUtils::makeDescriptor(desc);
} }
std::shared_ptr<MemoryDesc> FullyConnected::getDstMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) { std::shared_ptr<MemoryDesc> FullyConnected::getDstMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const {
auto desc = primitive_desc_it.dst_desc(idx); auto desc = prim_desc.dst_desc(idx);
if (getOutputShapeAtPort(idx).getRank() == 3) { if (getOutputShapeAtPort(idx).getRank() == 3) {
return std::make_shared<CpuBlockedMemoryDesc>( return std::make_shared<CpuBlockedMemoryDesc>(

View File

@ -33,7 +33,7 @@ public:
return getOutputShapeAtPort(0).getRank() == 3 ? 2 : 1; return getOutputShapeAtPort(0).getRank() == 3 ? 2 : 1;
} }
const std::vector<impl_desc_type>& getPrimitivesPriority() override; const std::vector<impl_desc_type>& getDefaultImplPriority() override;
void createDescriptor(const std::vector<MemoryDescPtr>& inputDesc, void createDescriptor(const std::vector<MemoryDescPtr>& inputDesc,
const std::vector<MemoryDescPtr>& outputDesc) override; const std::vector<MemoryDescPtr>& outputDesc) override;
@ -44,8 +44,8 @@ public:
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;
void initOptimalPrimitiveDescriptor() override; void initOptimalPrimitiveDescriptor() override;
void createPrimitive() override; void createPrimitive() override;
std::shared_ptr<MemoryDesc> getSrcMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) override; std::shared_ptr<MemoryDesc> getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const override;
std::shared_ptr<MemoryDesc> getDstMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) override; std::shared_ptr<MemoryDesc> getDstMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const override;
InferenceEngine::Precision getRuntimePrecision() const override; InferenceEngine::Precision getRuntimePrecision() const override;

View File

@ -2096,7 +2096,7 @@ void Interpolate::initSupportedPrimitiveDescriptors() {
} }
auto factory = std::make_shared<InterpolateExecutorFactory>(interpAttrs, srcMemoryDescs, dstMemoryDescs, auto factory = std::make_shared<InterpolateExecutorFactory>(interpAttrs, srcMemoryDescs, dstMemoryDescs,
std::make_shared<ExecutorContext>(context, getPrimitivesPriority())); std::make_shared<ExecutorContext>(context, getImplPriority()));
if (!factory->isEmpty()) { if (!factory->isEmpty()) {
supportedPrimitiveDescriptors.push_back({config, implDetail, factory}); supportedPrimitiveDescriptors.push_back({config, implDetail, factory});
} }

View File

@ -151,14 +151,14 @@ void Lrn::getSupportedDescriptors() {
} }
} }
std::shared_ptr<MemoryDesc> Lrn::getSrcMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) { std::shared_ptr<MemoryDesc> Lrn::getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const {
if (idx > 0) { if (idx > 0) {
return std::make_shared<CpuBlockedMemoryDesc>(getOriginalInputPrecisionAtPort(idx), getInputShapeAtPort(idx)); return std::make_shared<CpuBlockedMemoryDesc>(getOriginalInputPrecisionAtPort(idx), getInputShapeAtPort(idx));
} else { } else {
if (getInputShapeAtPort(idx).isDynamic()) { if (getInputShapeAtPort(idx).isDynamic()) {
return DnnlExtensionUtils::makeUndefinedDesc(primitive_desc_it.src_desc(idx), getInputShapeAtPort(idx)); return DnnlExtensionUtils::makeUndefinedDesc(prim_desc.src_desc(idx), getInputShapeAtPort(idx));
} }
return DnnlExtensionUtils::makeDescriptor(primitive_desc_it.src_desc(idx)); return DnnlExtensionUtils::makeDescriptor(prim_desc.src_desc(idx));
} }
} }

View File

@ -25,7 +25,7 @@ public:
size_t descInputNumbers() override { size_t descInputNumbers() override {
return static_cast<size_t>(getOriginalInputsNumber()); return static_cast<size_t>(getOriginalInputsNumber());
} }
std::shared_ptr<MemoryDesc> getSrcMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) override; std::shared_ptr<MemoryDesc> getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const override;
bool created() const override; bool created() const override;
bool canBeInPlace() const override { bool canBeInPlace() const override {
return false; return false;

View File

@ -506,39 +506,50 @@ void MatMul::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty()) if (!supportedPrimitiveDescriptors.empty())
return; return;
for (auto& desc : descs) { auto addSupportedPrimitiveDescriptor = [&](const dnnl::primitive_desc& prim_desc) {
auto itpd = desc; std::vector<PortConfig> inConfs, outConfs;
while (itpd) { const int inPlaceOutPort = canBeInPlace() ? 0 : -1;
NodeConfig config;
for (size_t i = 0; i < descInputNumbers(); i++) {
PortConfig portConfig;
portConfig.inPlace(-1);
portConfig.constant(false);
portConfig.setMemDesc(getSrcMemDesc(itpd, i));
config.inConfs.push_back(portConfig); for (size_t i = 0; i < descInputNumbers(); i++) {
auto desc = getSrcMemDesc(prim_desc, i);
inConfs.emplace_back(desc);
} }
for (size_t i = 0; i < descOutputNumbers(); i++) { for (size_t i = 0; i < descOutputNumbers(); i++) {
PortConfig portConfig; auto desc = getDstMemDesc(prim_desc, i);
portConfig.inPlace(canBeInPlace() ? 0 : -1);
portConfig.constant(false);
portConfig.setMemDesc(getDstMemDesc(itpd, i));
config.outConfs.push_back(portConfig); outConfs.emplace_back(desc, BlockedMemoryDesc::FULL_MASK, inPlaceOutPort);
} }
impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str()); const NodeConfig config(inConfs, outConfs);
const impl_desc_type impl_type = parse_impl_name(prim_desc.impl_info_str());
supportedPrimitiveDescriptors.emplace_back(config, impl_type); supportedPrimitiveDescriptors.emplace_back(config, impl_type);
if (!itpd.next_impl()) };
break;
} for (auto& desc : descs) {
auto first_desc = dnnl::primitive_desc(DnnlExtensionUtils::clone_primitive_desc(desc.get()));
const bool first_match = customImplPriorities.empty();
DnnlExtensionUtils::for_each_implementation(desc,
first_match,
[&](impl_desc_type implType) {
return contains(getImplPriority(), implType);
},
[&](dnnl::primitive_desc& desc) {
addSupportedPrimitiveDescriptor(desc);
});
// fallback. if none of the primitive types is present in the priority list just add first implementation
// @todo this fallback is not necessary if primitive priority list is filled correctly
if (supportedPrimitiveDescriptors.empty())
addSupportedPrimitiveDescriptor(first_desc);
} }
} }
MemoryDescPtr MatMul::getSrcMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) { MemoryDescPtr MatMul::getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const {
auto desc = idx > 0 ? primitive_desc_it.weights_desc(idx - 1): primitive_desc_it.src_desc(idx); auto desc = idx > 0 ? prim_desc.weights_desc(idx - 1): prim_desc.src_desc(idx);
if (idx < 2) // inputs if (idx < 2) // inputs
return std::make_shared<CpuBlockedMemoryDesc>( return std::make_shared<CpuBlockedMemoryDesc>(
@ -679,8 +690,8 @@ void MatMul::executeDynamicImpl(dnnl::stream strm) {
execute(strm); execute(strm);
} }
const std::vector<impl_desc_type>& MatMul::getPrimitivesPriority() { const std::vector<impl_desc_type>& MatMul::getDefaultImplPriority() {
std::vector<impl_desc_type> priorities = { static const std::vector<impl_desc_type> priorities = {
impl_desc_type::unknown, impl_desc_type::unknown,
impl_desc_type::brgemm_avx512_amx, impl_desc_type::brgemm_avx512_amx,
impl_desc_type::brgemm_avx512, impl_desc_type::brgemm_avx512,
@ -710,11 +721,8 @@ const std::vector<impl_desc_type>& MatMul::getPrimitivesPriority() {
impl_desc_type::jit_sse42, impl_desc_type::jit_sse42,
impl_desc_type::ref, impl_desc_type::ref,
}; };
for (const auto& impl : priorities) {
if (std::find(implPriorities.begin(), implPriorities.end(), impl) == implPriorities.end()) return priorities;
implPriorities.push_back(impl);
}
return implPriorities;
} }
} // namespace node } // namespace node
} // namespace intel_cpu } // namespace intel_cpu

View File

@ -24,7 +24,7 @@ public:
void createDescriptor(const std::vector<MemoryDescPtr>& inputDesc, void createDescriptor(const std::vector<MemoryDescPtr>& inputDesc,
const std::vector<MemoryDescPtr>& outputDesc) override; const std::vector<MemoryDescPtr>& outputDesc) override;
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;
MemoryDescPtr getSrcMemDesc(dnnl::primitive_desc_iterator &primitive_desc_it, size_t idx) override; MemoryDescPtr getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const override;
bool canFuse(const NodePtr& node) const override; bool canFuse(const NodePtr& node) const override;
bool created() const override; bool created() const override;
@ -42,7 +42,7 @@ public:
void executeDynamicImpl(dnnl::stream strm) override; void executeDynamicImpl(dnnl::stream strm) override;
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept; static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
const std::vector<impl_desc_type>& getPrimitivesPriority() override; const std::vector<impl_desc_type>& getDefaultImplPriority() override;
protected: protected:
AttrPtr initPrimitiveAttr() override; AttrPtr initPrimitiveAttr() override;

View File

@ -1219,7 +1219,7 @@ void MVN::initSupportedPrimitiveDescriptors() {
} }
auto factory = std::make_shared<MVNExecutorFactory>(mvnAttrs, srcMemoryDescs, dstMemoryDescs, auto factory = std::make_shared<MVNExecutorFactory>(mvnAttrs, srcMemoryDescs, dstMemoryDescs,
std::make_shared<ExecutorContext>(context, getPrimitivesPriority())); std::make_shared<ExecutorContext>(context, getImplPriority()));
if (!factory->isEmpty()) { if (!factory->isEmpty()) {
supportedPrimitiveDescriptors.push_back({config, impl_type, factory}); supportedPrimitiveDescriptors.push_back({config, impl_type, factory});
} }

View File

@ -4,6 +4,7 @@
#pragma once #pragma once
#include <memory>
#include "memory_desc/cpu_memory_desc.h" #include "memory_desc/cpu_memory_desc.h"
#include "memory_desc/blocked_memory_desc.h" #include "memory_desc/blocked_memory_desc.h"
@ -77,29 +78,30 @@ public:
private: private:
BlockedMemoryDescPtr _memDesc; BlockedMemoryDescPtr _memDesc;
CmpMask _cmpMask = BLOCKED_DESC_FULL_MASK; CmpMask _cmpMask = BlockedMemoryDesc::FULL_MASK;
}; };
class PortConfig { class PortConfig {
public: public:
PortConfig() = default; PortConfig() = default;
PortConfig(const PortConfig& rhs) { PortConfig(MemoryDescPtr desc,
this->_constant = rhs._constant; BlockedMemoryDesc::CmpMask cmpMask = BlockedMemoryDesc::FULL_MASK,
this->_inPlacePort = rhs._inPlacePort; int inPlacePort = -1,
if (rhs._desc) { bool isConstant = false)
this->_desc = rhs._desc; : _desc(createPortDesc(desc, cmpMask)),
} _inPlacePort(inPlacePort),
} _constant(isConstant) {}
PortConfig& operator=(const PortConfig& rhs) { // prevent implicit convertion of cmpMask
this->_constant = rhs._constant; PortConfig(MemoryDescPtr desc,
this->_inPlacePort = rhs._inPlacePort; int cmpMask,
if (rhs._desc) { int inPlacePort = -1,
this->_desc = rhs._desc; bool isConstant = false) = delete;
}
return *this; PortConfig(const PortConfig& rhs) = default;
}
PortConfig& operator=(const PortConfig& rhs) = default;
PortConfig(PortConfig&& rhs) = default; PortConfig(PortConfig&& rhs) = default;
PortConfig& operator=(PortConfig&& rhs) = default; PortConfig& operator=(PortConfig&& rhs) = default;
@ -124,29 +126,42 @@ public:
return _desc->getMemDesc(); return _desc->getMemDesc();
} }
void setMemDesc(MemoryDescPtr desc) {
if (desc->getType() & Blocked) {
setMemDesc(std::dynamic_pointer_cast<BlockedMemoryDesc>(desc), BLOCKED_DESC_FULL_MASK);
} else {
_desc = std::make_shared<PortDescGeneric>(desc);
}
}
void setMemDesc(BlockedMemoryDescPtr desc, BlockedMemoryDesc::CmpMask cmpMask) {
_desc = std::make_shared<PortDescBlocked>(desc, cmpMask);
}
PortDescBasePtr getPortDesc() const { PortDescBasePtr getPortDesc() const {
return _desc; return _desc;
} }
void setMemDesc(MemoryDescPtr desc) {
_desc = createPortDesc(desc, BlockedMemoryDesc::FULL_MASK);
}
void setMemDesc(BlockedMemoryDescPtr desc, BlockedMemoryDesc::CmpMask cmpMask) {
_desc = createPortDesc(desc, cmpMask);
}
private: private:
bool _constant = false; PortDescBasePtr createPortDesc(MemoryDescPtr desc, BlockedMemoryDesc::CmpMask cmpMask) {
int _inPlacePort = -1; if (desc->getType() & Blocked)
return createPortDesc(std::dynamic_pointer_cast<BlockedMemoryDesc>(desc), cmpMask);
return std::make_shared<PortDescGeneric>(desc);
}
PortDescBasePtr createPortDesc(BlockedMemoryDescPtr desc, BlockedMemoryDesc::CmpMask cmpMask) {
return std::make_shared<PortDescBlocked>(desc, cmpMask);
}
PortDescBasePtr _desc; PortDescBasePtr _desc;
int _inPlacePort = -1;
bool _constant = false;
}; };
struct NodeConfig { struct NodeConfig {
NodeConfig() = default;
NodeConfig(std::vector<PortConfig> inConfs, std::vector<PortConfig> outConfs)
: inConfs(std::move(inConfs)), outConfs(std::move(outConfs))
{}
std::vector<PortConfig> inConfs; std::vector<PortConfig> inConfs;
std::vector<PortConfig> outConfs; std::vector<PortConfig> outConfs;
}; };

View File

@ -46,8 +46,6 @@ NonZero::NonZero(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CP
} }
void NonZero::getSupportedDescriptors() { void NonZero::getSupportedDescriptors() {
if (!descs.empty())
return;
if (getParentEdges().size() != 1) if (getParentEdges().size() != 1)
IE_THROW() << errorPrefix << "has incorrect number of input edges: " << getParentEdges().size(); IE_THROW() << errorPrefix << "has incorrect number of input edges: " << getParentEdges().size();
if (!getChildEdges().size()) if (!getChildEdges().size())

View File

@ -16,6 +16,7 @@
#include <utils/general_utils.h> #include <utils/general_utils.h>
#include <memory_desc/cpu_memory_desc_utils.h> #include <memory_desc/cpu_memory_desc_utils.h>
#include "memory_desc/dnnl_blocked_memory_desc.h" #include "memory_desc/dnnl_blocked_memory_desc.h"
#include "nodes/node_config.h"
#include <common/primitive_hashing_utils.hpp> #include <common/primitive_hashing_utils.hpp>
// to access and change C pooling primitive desc internal padding field // to access and change C pooling primitive desc internal padding field
@ -576,6 +577,8 @@ void Pooling::createDescriptor(const std::vector<MemoryDescPtr> &inputDesc,
const auto& out_candidate = dnnlOutDesc.getDnnlDesc(); const auto& out_candidate = dnnlOutDesc.getDnnlDesc();
auto desc = createDescriptorInternal(in_candidate, out_candidate, getPoolingAlgorithm()); auto desc = createDescriptorInternal(in_candidate, out_candidate, getPoolingAlgorithm());
if (desc)
descs.emplace_back(desc); descs.emplace_back(desc);
} }
@ -583,9 +586,6 @@ void Pooling::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty()) if (!supportedPrimitiveDescriptors.empty())
return; return;
dnnl::primitive_attr attr;
setPostOps(attr);
if (useACL) { if (useACL) {
auto& creatorsMap = BlockedDescCreator::getCommonCreators(); auto& creatorsMap = BlockedDescCreator::getCommonCreators();
auto pushDesc = [&](LayoutType format) { auto pushDesc = [&](LayoutType format) {
@ -599,65 +599,74 @@ void Pooling::initSupportedPrimitiveDescriptors() {
creatorsMap.at(format)->createSharedDesc(getOriginalOutputPrecisionAtPort(0), getOutputShapeAtPort(0))); creatorsMap.at(format)->createSharedDesc(getOriginalOutputPrecisionAtPort(0), getOutputShapeAtPort(0)));
std::vector<MemoryDescPtr> srcMemoryDescs; std::vector<MemoryDescPtr> srcMemoryDescs;
for (size_t i = 0; i < config.inConfs.size(); i++) { for (const auto& inConf : config.inConfs) {
srcMemoryDescs.push_back(config.inConfs[i].getMemDesc()); srcMemoryDescs.push_back(inConf.getMemDesc());
} }
std::vector<MemoryDescPtr> dstMemoryDescs; std::vector<MemoryDescPtr> dstMemoryDescs;
for (size_t i = 0; i < config.outConfs.size(); i++) { for (const auto& outConf : config.outConfs) {
dstMemoryDescs.push_back(config.outConfs[i].getMemDesc()); dstMemoryDescs.push_back(outConf.getMemDesc());
} }
auto factory = std::make_shared<PoolingExecutorFactory>( auto factory = std::make_shared<PoolingExecutorFactory>(
poolingAttrs, poolingAttrs,
srcMemoryDescs, srcMemoryDescs,
dstMemoryDescs, dstMemoryDescs,
std::make_shared<ExecutorContext>(context, getPrimitivesPriority())); std::make_shared<ExecutorContext>(context, getImplPriority()));
supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::undef, factory); supportedPrimitiveDescriptors.emplace_back(config, impl_desc_type::undef, factory);
}; };
pushDesc(LayoutType::ncsp); pushDesc(LayoutType::ncsp);
} else {
for (auto& desc : descs) {
auto itpd = desc;
while (static_cast<bool>(itpd)) { return;
NodeConfig config; }
auto addSupportedPrimitiveDescriptor = [&](const dnnl::primitive_desc& prim_desc) {
std::vector<PortConfig> inConfs, outConfs;
const int inPlaceOutPort = canBeInPlace() ? 0 : -1;
for (size_t i = 0; i < descInputNumbers(); i++) { for (size_t i = 0; i < descInputNumbers(); i++) {
PortConfig dataConfig; auto desc = getSrcMemDesc(prim_desc, i);
dataConfig.inPlace(-1); inConfs.emplace_back(desc);
dataConfig.constant(false);
dataConfig.setMemDesc(getSrcMemDesc(itpd, i));
config.inConfs.push_back(dataConfig);
} }
for (size_t i = 0; i < descOutputNumbers(); i++) { for (size_t i = 0; i < descOutputNumbers(); i++) {
PortConfig dataConfig; auto desc = getDstMemDesc(prim_desc, i);
dataConfig.inPlace(canBeInPlace() ? 0 : -1); // PortConfig in{desc, inPlaceOutPort};
dataConfig.constant(false); outConfs.emplace_back(desc, BlockedMemoryDesc::FULL_MASK, inPlaceOutPort);
dataConfig.setMemDesc(getDstMemDesc(itpd, i));
config.outConfs.push_back(dataConfig);
} }
// CPU plugin doesn't support second output of MaxPool-8, but anyway we should have out config for second port as stub // CPU plugin doesn't support second output of MaxPool-8, but anyway we should have out config for second port as stub
if (isMaxPool8) { if (isMaxPool8) {
auto& creatorsMap = BlockedDescCreator::getCommonCreators(); const auto& creatorsMap = BlockedDescCreator::getCommonCreators();
PortConfig dataConfig; const auto outputPrecision = outConfs.front().getMemDesc()->getPrecision();
dataConfig.inPlace(-1); auto desc = creatorsMap.at(LayoutType::ncsp)->createSharedDesc(outputPrecision, getOutputShapeAtPort(1));
dataConfig.constant(false);
dataConfig.setMemDesc(creatorsMap.at(LayoutType::ncsp)->createSharedDesc(config.outConfs.front().getMemDesc()->getPrecision(),
getOutputShapeAtPort(1)));
config.outConfs.push_back(dataConfig); outConfs.emplace_back(desc);
} }
impl_desc_type impl_type = parse_impl_name(itpd.impl_info_str()); const NodeConfig config(inConfs, outConfs);
const impl_desc_type impl_type = parse_impl_name(prim_desc.impl_info_str());
supportedPrimitiveDescriptors.emplace_back(config, impl_type); supportedPrimitiveDescriptors.emplace_back(config, impl_type);
if (!itpd.next_impl()) };
break;
} for (auto& desc : descs) {
} auto first_desc = dnnl::primitive_desc(DnnlExtensionUtils::clone_primitive_desc(desc.get()));
const bool first_match = customImplPriorities.empty();
DnnlExtensionUtils::for_each_implementation(desc,
first_match,
[&](impl_desc_type implType) {
return contains(getImplPriority(), implType);
},
[&](dnnl::primitive_desc& desc) {
addSupportedPrimitiveDescriptor(desc);
});
// fallback. if none of the primitive types is present in the priority list just add first implementation
// @todo this fallback is not necessary if primitive priority list is filled correctly
if (supportedPrimitiveDescriptors.empty())
addSupportedPrimitiveDescriptor(first_desc);
} }
} }

View File

@ -1771,9 +1771,6 @@ Reduce::Reduce(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CPtr
} }
void Reduce::getSupportedDescriptors() { void Reduce::getSupportedDescriptors() {
if (!descs.empty())
return;
if (getParentEdges().size() != 2) if (getParentEdges().size() != 2)
IE_THROW() << errorPrefix << " gets incorrect number of input edges!"; IE_THROW() << errorPrefix << " gets incorrect number of input edges!";
if (getChildEdges().empty()) if (getChildEdges().empty())
@ -1858,7 +1855,7 @@ void Reduce::initSupportedPrimitiveDescriptors() {
} }
auto factory = std::make_shared<ReduceExecutorFactory>(reduceAttrs, srcMemoryDescs, dstMemoryDescs, auto factory = std::make_shared<ReduceExecutorFactory>(reduceAttrs, srcMemoryDescs, dstMemoryDescs,
std::make_shared<ExecutorContext>(context, getPrimitivesPriority())); std::make_shared<ExecutorContext>(context, getImplPriority()));
if (!factory->isEmpty()) { if (!factory->isEmpty()) {
supportedPrimitiveDescriptors.push_back({config, impl_type, factory}); supportedPrimitiveDescriptors.push_back({config, impl_type, factory});
} }

View File

@ -253,9 +253,10 @@ void Reorder::createReorderPrimitive(const dnnl::memory::desc& srcDesc,
#endif #endif
} }
const std::vector<impl_desc_type>& Reorder::getPrimitivesPriority() { const std::vector<impl_desc_type>& Reorder::getDefaultImplPriority() {
implPriorities = {impl_desc_type::reorder}; static const std::vector<impl_desc_type> priorities = {impl_desc_type::reorder};
return implPriorities;
return priorities;
} }
bool Reorder::created() const { bool Reorder::created() const {

View File

@ -24,7 +24,7 @@ public:
void initSupportedPrimitiveDescriptors() override; void initSupportedPrimitiveDescriptors() override;
void execute(dnnl::stream strm) override; void execute(dnnl::stream strm) override;
bool created() const override; bool created() const override;
const std::vector<impl_desc_type>& getPrimitivesPriority() override; const std::vector<impl_desc_type>& getDefaultImplPriority() override;
bool isExecutable() const override; bool isExecutable() const override;

View File

@ -956,7 +956,7 @@ void RNN::fillDescs() {
wDescs, wDescs,
*attr); *attr);
descs.push_back(desc); descs.emplace_back(desc);
} }
void RNN::createDescriptor(const std::vector<MemoryDescPtr> &inputDesc, void RNN::createDescriptor(const std::vector<MemoryDescPtr> &inputDesc,
@ -1109,11 +1109,13 @@ void RNN::prepareParams() {
primArgs[DNNL_ARG_SCRATCHPAD] = scratchpadMem->GetPrimitive(); primArgs[DNNL_ARG_SCRATCHPAD] = scratchpadMem->GetPrimitive();
} }
std::shared_ptr<MemoryDesc> RNN::getSrcMemDesc(dnnl::primitive_desc_iterator& primitive_desc_it, size_t idx) { std::shared_ptr<MemoryDesc> RNN::getSrcMemDesc(const dnnl::primitive_desc& prim_desc, size_t idx) const {
(void) prim_desc;
return supportedPrimitiveDescriptors[0].getConfig().inConfs[idx].getMemDesc(); return supportedPrimitiveDescriptors[0].getConfig().inConfs[idx].getMemDesc();
} }
std::shared_ptr<MemoryDesc> RNN::getDstMemDesc(dnnl::primitive_desc_iterator& primitive_desc_it, size_t idx) { std::shared_ptr<MemoryDesc> RNN::getDstMemDesc(const dnnl::primitive_desc& prim_desc, size_t idx) const {
(void) prim_desc;
return supportedPrimitiveDescriptors[0].getConfig().outConfs[idx].getMemDesc(); return supportedPrimitiveDescriptors[0].getConfig().outConfs[idx].getMemDesc();
} }

View File

@ -25,8 +25,8 @@ public:
static bool isCell(const std::shared_ptr<const ngraph::Node>& op); static bool isCell(const std::shared_ptr<const ngraph::Node>& op);
static bool testNativeOrder(const std::shared_ptr<const ngraph::Node>& op); static bool testNativeOrder(const std::shared_ptr<const ngraph::Node>& op);
void getSupportedDescriptors() override; void getSupportedDescriptors() override;
std::shared_ptr<MemoryDesc> getSrcMemDesc(dnnl::primitive_desc_iterator& primitive_desc_it, size_t idx) override; std::shared_ptr<MemoryDesc> getSrcMemDesc(const dnnl::primitive_desc& prim_desc, size_t idx) const override;
std::shared_ptr<MemoryDesc> getDstMemDesc(dnnl::primitive_desc_iterator& primitive_desc_it, size_t idx) override; std::shared_ptr<MemoryDesc> getDstMemDesc(const dnnl::primitive_desc& prim_desc, size_t idx) const override;
bool created() const override; bool created() const override;
void createDescriptor(const std::vector<MemoryDescPtr>& inputDesc, void createDescriptor(const std::vector<MemoryDescPtr>& inputDesc,
const std::vector<MemoryDescPtr>& outputDesc) override; const std::vector<MemoryDescPtr>& outputDesc) override;

View File

@ -705,9 +705,6 @@ ROIAlign::ROIAlign(const std::shared_ptr<ngraph::Node>& op, const GraphContext::
} }
void ROIAlign::getSupportedDescriptors() { void ROIAlign::getSupportedDescriptors() {
if (!descs.empty())
return;
if (getParentEdges().size() != 3) if (getParentEdges().size() != 3)
IE_THROW() << errorPrefix << "has incorrect number of input edges: " << getParentEdges().size(); IE_THROW() << errorPrefix << "has incorrect number of input edges: " << getParentEdges().size();
if (getChildEdges().empty()) if (getChildEdges().empty())

View File

@ -407,9 +407,6 @@ ROIPooling::ROIPooling(const std::shared_ptr<ngraph::Node>& op, const GraphConte
} }
void ROIPooling::getSupportedDescriptors() { void ROIPooling::getSupportedDescriptors() {
if (!descs.empty())
return;
if (getParentEdges().size() != 2) if (getParentEdges().size() != 2)
IE_THROW() << errorPrefix << "has incorrect number of input edges: " << getParentEdges().size(); IE_THROW() << errorPrefix << "has incorrect number of input edges: " << getParentEdges().size();
if (getChildEdges().empty()) if (getChildEdges().empty())

View File

@ -68,8 +68,6 @@ ShapeOf::ShapeOf(const std::shared_ptr<ngraph::Node>& op, const GraphContext::CP
} }
void ShapeOf::getSupportedDescriptors() { void ShapeOf::getSupportedDescriptors() {
if (!descs.empty())
return;
if (getParentEdges().size() != 1) if (getParentEdges().size() != 1)
IE_THROW() << errorPrefix << "has incorrect number of input edges: " << getParentEdges().size(); IE_THROW() << errorPrefix << "has incorrect number of input edges: " << getParentEdges().size();
if (getChildEdges().empty()) if (getChildEdges().empty())

View File

@ -124,7 +124,7 @@ void SoftMax::initOptimalPrimitiveDescriptor() {
auto config = selected_pd->getConfig(); auto config = selected_pd->getConfig();
if (isDynamicNode()) { if (isDynamicNode()) {
auto outMemDesc = config.outConfs[0].getMemDesc(); auto outMemDesc = config.outConfs[0].getMemDesc();
config.outConfs[0].setMemDesc(std::dynamic_pointer_cast<BlockedMemoryDesc>(outMemDesc), BLOCKED_DESC_FULL_MASK); config.outConfs[0].setMemDesc(std::dynamic_pointer_cast<BlockedMemoryDesc>(outMemDesc), BlockedMemoryDesc::FULL_MASK);
} else { } else {
if (config.inConfs.size() != 1 || config.outConfs.size() != 1 || if (config.inConfs.size() != 1 || config.outConfs.size() != 1 ||
(config.inConfs[0].getMemDesc()->isDefined() && (config.inConfs[0].getMemDesc()->isDefined() &&
@ -155,7 +155,8 @@ void SoftMax::createDescriptor(const std::vector<MemoryDescPtr> &inputDesc,
*attr, *attr,
true); true);
descs.push_back(desc); if (desc)
descs.emplace_back(desc);
} }
void SoftMax::prepareParams() { void SoftMax::prepareParams() {

View File

@ -178,7 +178,7 @@ void Split::initSupportedPrimitiveDescriptors() {
SizeVector strides(numOfDim); SizeVector strides(numOfDim);
strides.back() = 1lu; strides.back() = 1lu;
size_t offset = Shape::UNDEFINED_DIM; size_t offset = Shape::UNDEFINED_DIM;
BlockedMemoryDesc::CmpMask mask = BLOCKED_DESC_SKIP_OFFSET_MASK; // accepts any offset BlockedMemoryDesc::CmpMask mask = BlockedMemoryDesc::SKIP_OFFSET_MASK; // accepts any offset
for (size_t i = 2; i <= numOfDim; i++) { for (size_t i = 2; i <= numOfDim; i++) {
if (numOfDim - i < axis) { if (numOfDim - i < axis) {
@ -363,14 +363,16 @@ void Split::initOptimalPrimitiveDescriptor() {
auto outBlockingDesc = oldDesc->as<BlockedMemoryDesc>(); auto outBlockingDesc = oldDesc->as<BlockedMemoryDesc>();
const auto& shape = outBlockingDesc->getShape(); const auto& shape = outBlockingDesc->getShape();
const auto& blkDims = outBlockingDesc->getBlockDims(); const auto& blkDims = outBlockingDesc->getBlockDims();
config.outConfs[i].setMemDesc(std::make_shared<CpuBlockedMemoryDesc>(outBlockingDesc->getPrecision(), config.outConfs[i].setMemDesc(std::make_shared<CpuBlockedMemoryDesc>(
outBlockingDesc->getPrecision(),
shape, shape,
blkDims, blkDims,
outBlockingDesc->getOrder(), outBlockingDesc->getOrder(),
firstInBlockingDesc->getOffsetPadding() + offset, firstInBlockingDesc->getOffsetPadding() + offset,
firstInBlockingDesc->getOffsetPaddingToData(), firstInBlockingDesc->getOffsetPaddingToData(),
(shape.hasZeroDims() ? VectorDims(blkDims.size(), 0) : (shape.hasZeroDims() ? VectorDims(blkDims.size(), 0) :
firstInBlockingDesc->getStrides())), BLOCKED_DESC_FULL_MASK); firstInBlockingDesc->getStrides())),
BlockedMemoryDesc::FULL_MASK);
size_t axisSize = 1; size_t axisSize = 1;
for (size_t j = axis; j < outBlockingDesc->getBlockDims().size(); j++) { for (size_t j = axis; j < outBlockingDesc->getBlockDims().size(); j++) {
@ -398,7 +400,7 @@ void Split::selectOptimalPrimitiveDescriptor() {
// Enforce the reference implementation for the planar layout if the implementation is in the impl priorities list. // Enforce the reference implementation for the planar layout if the implementation is in the impl priorities list.
// This is needed mostly for the testing purposes, since for the planar layout Split works always in place, we need to enforce // This is needed mostly for the testing purposes, since for the planar layout Split works always in place, we need to enforce
// the reference implementation when it is selected in a test to test that piece of code. // the reference implementation when it is selected in a test to test that piece of code.
if (!implPriorities.empty() && implPriorities[0] == impl_desc_type::ref) { if (!customImplPriorities.empty() && customImplPriorities[0] == impl_desc_type::ref) {
for (size_t i = 0; i < supportedPrimitiveDescriptors.size(); ++i) { for (size_t i = 0; i < supportedPrimitiveDescriptors.size(); ++i) {
auto& pd = supportedPrimitiveDescriptors[i]; auto& pd = supportedPrimitiveDescriptors[i];
if (pd.getConfig().inConfs[0].getMemDesc()->hasLayoutType(LayoutType::ncsp) && if (pd.getConfig().inConfs[0].getMemDesc()->hasLayoutType(LayoutType::ncsp) &&

View File

@ -191,7 +191,7 @@ void Snippet::initSupportedPrimitiveDescriptors() {
const auto equalPrecisions = getOriginalOutputPrecisions().size() == 1 && const auto equalPrecisions = getOriginalOutputPrecisions().size() == 1 &&
precision == getOriginalOutputPrecisionAtPort(0); precision == getOriginalOutputPrecisionAtPort(0);
BlockedMemoryDesc::CmpMask inputMask = BLOCKED_DESC_SKIP_OFFSET_MASK; BlockedMemoryDesc::CmpMask inputMask = BlockedMemoryDesc::SKIP_OFFSET_MASK;
PortConfig portConfig; PortConfig portConfig;
portConfig.inPlace((!i && canBeInPlace() && equalPrecisions) ? 0 : -1); portConfig.inPlace((!i && canBeInPlace() && equalPrecisions) ? 0 : -1);
portConfig.constant(false); portConfig.constant(false);
@ -207,7 +207,7 @@ void Snippet::initSupportedPrimitiveDescriptors() {
if (supportedPrecisions.count(precision) == 0) if (supportedPrecisions.count(precision) == 0)
IE_THROW() << "Subgraph node with name `" << getName() << "` doesn't support " << precision << " precision."; IE_THROW() << "Subgraph node with name `" << getName() << "` doesn't support " << precision << " precision.";
BlockedMemoryDesc::CmpMask outputMask = BLOCKED_DESC_SKIP_OFFSET_MASK; BlockedMemoryDesc::CmpMask outputMask = BlockedMemoryDesc::SKIP_OFFSET_MASK;
PortConfig portConfig; PortConfig portConfig;
portConfig.inPlace(-1); portConfig.inPlace(-1);
portConfig.constant(false); portConfig.constant(false);
@ -235,7 +235,7 @@ void Snippet::initSupportedPrimitiveDescriptors() {
} }
void Snippet::selectOptimalPrimitiveDescriptor() { void Snippet::selectOptimalPrimitiveDescriptor() {
selectPreferPrimitiveDescriptor(getPrimitivesPriority(), true); selectPreferPrimitiveDescriptor(getImplPriority(), true);
} }
InferenceEngine::Precision Snippet::getRuntimePrecision() const { InferenceEngine::Precision Snippet::getRuntimePrecision() const {
std::vector<InferenceEngine::Precision> inputPrecisions; std::vector<InferenceEngine::Precision> inputPrecisions;

View File

@ -3,6 +3,7 @@
// //
#include "iml_type_mapper.h" #include "iml_type_mapper.h"
#include <algorithm>
namespace ov { namespace ov {
namespace intel_cpu { namespace intel_cpu {
@ -122,5 +123,9 @@ const char* impl_type_to_string(impl_desc_type type) {
return "unknown"; return "unknown";
} }
bool contains(const std::vector<impl_desc_type>& priorities, const impl_desc_type impl_type_str) {
return std::find(priorities.begin(), priorities.end(), impl_type_str) != priorities.end();
}
} // namespace intel_cpu } // namespace intel_cpu
} // namespace ov } // namespace ov

View File

@ -5,6 +5,7 @@
#pragma once #pragma once
#include <string> #include <string>
#include <vector>
namespace ov { namespace ov {
namespace intel_cpu { namespace intel_cpu {
@ -101,6 +102,7 @@ enum impl_desc_type {
const char * impl_type_to_string(impl_desc_type type); const char * impl_type_to_string(impl_desc_type type);
impl_desc_type parse_impl_name(std::string impl_desc_name); impl_desc_type parse_impl_name(std::string impl_desc_name);
bool contains(const std::vector<impl_desc_type>& priorities, const impl_desc_type impl_type_str);
} // namespace intel_cpu } // namespace intel_cpu
} // namespace ov } // namespace ov

View File

@ -9,6 +9,7 @@
#include <bitset> #include <bitset>
#include <unordered_map> #include <unordered_map>
#include <utility>
namespace ov { namespace ov {
namespace intel_cpu { namespace intel_cpu {
@ -79,13 +80,13 @@ public:
}; };
struct PropertyGroup { struct PropertyGroup {
virtual std::vector<PropertySetterPtr> getPropertySetters(void) = 0; virtual std::vector<PropertySetterPtr> getPropertySetters() = 0;
void parseAndSet(const std::string& str) { void parseAndSet(const std::string& str) {
const auto& options = ov::util::split(str, ' '); const auto& options = ov::util::split(str, ' ');
const auto& propertySetters = getPropertySetters(); const auto& propertySetters = getPropertySetters();
bool failed = false; bool failed = false;
auto getHelp = [propertySetters] (void) { auto getHelp = [propertySetters]() {
std::string help; std::string help;
for (const auto& property : propertySetters) for (const auto& property : propertySetters)
help.append('\t' + property->getPropertyName() + "=<" + property->getPropertyValueDescription() + ">\n"); help.append('\t' + property->getPropertyName() + "=<" + property->getPropertyValueDescription() + ">\n");
@ -118,7 +119,7 @@ public:
struct : PropertyGroup { struct : PropertyGroup {
TransformationFilter transformations; TransformationFilter transformations;
std::vector<PropertySetterPtr> getPropertySetters(void) override { std::vector<PropertySetterPtr> getPropertySetters() override {
return { transformations.getPropertySetter() }; return { transformations.getPropertySetter() };
} }
} disable; } disable;
@ -128,7 +129,7 @@ public:
IrFormatFilter format = { 1 << IrFormatFilter::Xml }; IrFormatFilter format = { 1 << IrFormatFilter::Xml };
TransformationFilter transformations; TransformationFilter transformations;
std::vector<PropertySetterPtr> getPropertySetters(void) override { std::vector<PropertySetterPtr> getPropertySetters() override {
return { PropertySetterPtr(new StringPropertySetter("dir", dir, "path to dumped IRs")), return { PropertySetterPtr(new StringPropertySetter("dir", dir, "path to dumped IRs")),
format.getPropertySetter(), format.getPropertySetter(),
transformations.getPropertySetter() }; transformations.getPropertySetter() };
@ -138,23 +139,29 @@ public:
private: private:
struct PropertySetter { struct PropertySetter {
virtual bool parseAndSet(const std::string& str) = 0; virtual bool parseAndSet(const std::string& str) = 0;
virtual std::string getPropertyValueDescription(void) const = 0; virtual std::string getPropertyValueDescription() const = 0;
PropertySetter(const std::string&& name) : propertyName(name) {} PropertySetter(std::string name) : propertyName(std::move(name)) {}
const std::string& getPropertyName(void) const { return propertyName; }
virtual ~PropertySetter() = default;
const std::string& getPropertyName() const { return propertyName; }
private: private:
const std::string propertyName; const std::string propertyName;
}; };
struct StringPropertySetter : PropertySetter { struct StringPropertySetter : PropertySetter {
StringPropertySetter(const std::string&& name, std::string& ref, const std::string&& valueDescription) StringPropertySetter(const std::string& name, std::string& ref, const std::string&& valueDescription)
: PropertySetter(std::move(name)), property(ref), propertyValueDescription(valueDescription) {} : PropertySetter(name), property(ref), propertyValueDescription(valueDescription) {}
~StringPropertySetter() override = default;
bool parseAndSet(const std::string& str) override { bool parseAndSet(const std::string& str) override {
property = str; property = str;
return true; return true;
} }
std::string getPropertyValueDescription(void) const override { return propertyValueDescription; } std::string getPropertyValueDescription() const override { return propertyValueDescription; }
private: private:
std::string& property; std::string& property;
@ -168,8 +175,11 @@ private:
std::vector<size_t> bits; std::vector<size_t> bits;
}; };
BitsetFilterPropertySetter(const std::string&& name, std::bitset<NumOfBits>& ref, const std::vector<Token>&& tokens) BitsetFilterPropertySetter(const std::string& name, std::bitset<NumOfBits>& ref, const std::vector<Token>&& tokens)
: PropertySetter(std::move(name)), property(ref), propertyTokens(tokens) {} : PropertySetter(name), property(ref), propertyTokens(tokens) {}
~BitsetFilterPropertySetter() override = default;
bool parseAndSet(const std::string& str) override { bool parseAndSet(const std::string& str) override {
const auto& tokens = str.empty() ? const auto& tokens = str.empty() ?
std::vector<std::string>{"all"} : ov::util::split(ov::util::to_lower(str), ','); std::vector<std::string>{"all"} : ov::util::split(ov::util::to_lower(str), ',');
@ -188,7 +198,7 @@ private:
} }
return true; return true;
} }
std::string getPropertyValueDescription(void) const override { std::string getPropertyValueDescription() const override {
std::string supportedTokens = "comma separated filter tokens: "; std::string supportedTokens = "comma separated filter tokens: ";
for (size_t i = 0; i < propertyTokens.size(); i++) { for (size_t i = 0; i < propertyTokens.size(); i++) {
if (i) if (i)

View File

@ -19,7 +19,7 @@ inline std::string getRTInfoValue(const std::map<std::string, ov::Any>& rtInfo,
} }
} }
inline std::string getPrimitivesPriorityValue(const std::shared_ptr<ngraph::Node> &node) { inline std::string getImplPriorityValue(const std::shared_ptr<ngraph::Node> &node) {
const auto &rtInfo = node->get_rt_info(); const auto &rtInfo = node->get_rt_info();
auto it_info = rtInfo.find(ov::PrimitivesPriority::get_type_info_static()); auto it_info = rtInfo.find(ov::PrimitivesPriority::get_type_info_static());

View File

@ -507,6 +507,23 @@ INSTANTIATE_TEST_SUITE_P(smoke_Conv_1D_GEMM_FP32, ConvolutionLayerCPUTest,
::testing::Values(cpuEmptyPluginConfig)), ::testing::Values(cpuEmptyPluginConfig)),
ConvolutionLayerCPUTest::getTestCaseName); ConvolutionLayerCPUTest::getTestCaseName);
// Verify that even if primitive is missed in custom priority list there is still a fallback to the default priority list
const auto conv_gemm_1D_improperPriorityList = CPUSpecificParams{{ncw}, {ncw}, {"unknown"}, "jit_gemm"};
INSTANTIATE_TEST_SUITE_P(smoke_Conv_1D_GEMM_FP32_ImproperPriorityList, ConvolutionLayerCPUTest,
::testing::Combine(
::testing::Combine(
convParams_ExplicitPadding_GEMM_1D,
::testing::Values(ElementType::f32),
::testing::Values(ElementType::undefined),
::testing::Values(ElementType::undefined),
::testing::ValuesIn(inShapesGemm1D),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
::testing::ValuesIn(filterCPUInfoForDevice({conv_gemm_1D})),
::testing::Values(emptyFusingSpec),
::testing::Values(cpuEmptyPluginConfig)),
ConvolutionLayerCPUTest::getTestCaseName);
INSTANTIATE_TEST_SUITE_P(smoke_Conv_1D_GEMM_BF16, ConvolutionLayerCPUTest, INSTANTIATE_TEST_SUITE_P(smoke_Conv_1D_GEMM_BF16, ConvolutionLayerCPUTest,
::testing::Combine( ::testing::Combine(
::testing::Combine( ::testing::Combine(

View File

@ -1914,11 +1914,11 @@ INSTANTIATE_TEST_SUITE_P(smoke_JIT_AVX512_DW_GroupConv, GroupConvolutionLayerCPU
/* ============= brgemm GroupConvolution test, expect fallback to other implementation ============= */ /* ============= brgemm GroupConvolution test, expect fallback to other implementation ============= */
const std::vector<CPUSpecificParams> CPUParams_Fallback_Brgemm_2D = { const std::vector<CPUSpecificParams> CPUParams_Fallback_Brgemm_2D = {
conv_avx512_2D_nspc_brgconv, CPUSpecificParams{{nhwc}, {nhwc}, {/* non-brgconv_avx512 is expected */}, "brgconv_avx512"},
conv_avx512_2D_nspc_brgconv_amx CPUSpecificParams{{nhwc}, {nhwc}, {/* non-brgconv_avx512_amx is expected */}, "brgconv_avx512_amx"},
}; };
const std::vector<CPUSpecificParams> CPUParams_Fallback_Brgemm_1D_Small_Shape = { const std::vector<CPUSpecificParams> CPUParams_Fallback_Brgemm_1D_Small_Shape = {
conv_avx512_1D_nspc_brgconv_amx CPUSpecificParams{{nwc}, {nwc}, {/* non-brgconv_avx512_amx is expected */}, "brgconv_avx512_amx"}
}; };
const std::vector<groupConvLayerCPUTestParamsSet> BRGEMM_EXPECT_FALLBACK_GroupConvTestCases = generateSingleGroupConvCPUTestCases( const std::vector<groupConvLayerCPUTestParamsSet> BRGEMM_EXPECT_FALLBACK_GroupConvTestCases = generateSingleGroupConvCPUTestCases(
// channel <= 16 // channel <= 16

View File

@ -251,7 +251,7 @@ std::vector<CPUSpecificParams> filterSpecificParams_BrgemmAmx() {
std::vector<CPUSpecificParams> filterSpecificParams_Brgconv1x1() { std::vector<CPUSpecificParams> filterSpecificParams_Brgconv1x1() {
std::vector<CPUSpecificParams> specificParams; std::vector<CPUSpecificParams> specificParams;
if (with_cpu_x86_avx512_core()) { if (with_cpu_x86_avx512_core()) {
specificParams.push_back(CPUSpecificParams{{}, {}, {"brgconv_avx512_1x1"}, "brgconv_avx512_1x1"}); specificParams.push_back(CPUSpecificParams{{}, {}, {/* brgconv_avx512_1x1 is not a part of fc impl list */}, "brgconv_avx512_1x1"});
} }
return specificParams; return specificParams;

View File

@ -11,9 +11,9 @@ namespace CPUTestUtils {
const auto conv_ref_2D = CPUSpecificParams{{nchw}, {nchw}, {"ref_any"}, "ref_any"}; const auto conv_ref_2D = CPUSpecificParams{{nchw}, {nchw}, {"ref_any"}, "ref_any"};
const auto conv_ref_3D = CPUSpecificParams{{ncdhw}, {ncdhw}, {"ref_any"}, "ref_any"}; const auto conv_ref_3D = CPUSpecificParams{{ncdhw}, {ncdhw}, {"ref_any"}, "ref_any"};
const auto conv_gemm_1D = CPUSpecificParams{{ncw}, {ncw}, {"gemm_any"}, "jit_gemm"}; const auto conv_gemm_1D = CPUSpecificParams{{ncw}, {ncw}, {"jit_gemm"}, "jit_gemm"};
const auto conv_gemm_2D = CPUSpecificParams{{nchw}, {nchw}, {"gemm_any"}, "jit_gemm"}; const auto conv_gemm_2D = CPUSpecificParams{{nchw}, {nchw}, {"jit_gemm"}, "jit_gemm"};
const auto conv_gemm_3D = CPUSpecificParams{{ncdhw}, {ncdhw}, {"gemm_any"}, "jit_gemm"}; const auto conv_gemm_3D = CPUSpecificParams{{ncdhw}, {ncdhw}, {"jit_gemm"}, "jit_gemm"};
const auto conv_gemm_1D_nspc = CPUSpecificParams{{nwc}, {nwc}, {"jit_gemm"}, "jit_gemm"}; const auto conv_gemm_1D_nspc = CPUSpecificParams{{nwc}, {nwc}, {"jit_gemm"}, "jit_gemm"};
const auto conv_gemm_2D_nspc = CPUSpecificParams{{nhwc}, {nhwc}, {"jit_gemm"}, "jit_gemm"}; const auto conv_gemm_2D_nspc = CPUSpecificParams{{nhwc}, {nhwc}, {"jit_gemm"}, "jit_gemm"};