[CPU] Add early throw for filtering primitive descriptors (#17065)

Otherwise the plugin will throw later when trying to init optimal
primitive descriptor (when there is actually no descriptors at all)
This commit is contained in:
Egor Duplenskii
2023-06-19 14:16:02 +02:00
committed by GitHub
parent e54bd6ab1b
commit 6143a4fa42
6 changed files with 7 additions and 85 deletions

View File

@@ -335,8 +335,9 @@ void Node::selectPreferPrimitiveDescriptor(const std::vector<impl_desc_type>& pr
}
}
if (getSupportedPrimitiveDescriptors().empty())
IE_THROW() << "Supported primitive descriptors list is empty for node: " << getName();
IE_ASSERT(!getSupportedPrimitiveDescriptors().empty()) <<
"Supported primitive descriptors list is empty for node: " << getName() << " type: " << NameFromType(getType());
// fallback. If there are no primitives from priority list just select a first
selectPrimitiveDescriptorByIndex(0);
}
@@ -711,6 +712,9 @@ void Node::filterSupportedPrimitiveDescriptors() {
supportedPrimitiveDescriptors.erase(
std::remove_if(supportedPrimitiveDescriptors.begin(), supportedPrimitiveDescriptors.end(), isNotSuitableDesc),
supportedPrimitiveDescriptors.end());
IE_ASSERT(!supportedPrimitiveDescriptors.empty()) << getName() << " type: " << NameFromType(getType()) <<
" No supported primitive descriptors matched the provided input / output memory format filters.";
}
void Node::initDescriptor(const NodeConfig& config) {

View File

@@ -369,7 +369,7 @@ public:
* @brief Filters supportedPrimitiveDescriptors according to the input layouts specified in inputMemoryFormatsFilter
* and output layouts specified in outputMemoryFormatsFilter
*/
virtual void filterSupportedPrimitiveDescriptors();
void filterSupportedPrimitiveDescriptors();
virtual void createPrimitive();

View File

@@ -1051,43 +1051,6 @@ void Convolution::initDescriptor(const NodeConfig& config) {
selectedPD->setConfig(updatedConfig);
}
void Convolution::filterSupportedPrimitiveDescriptors() {
Node::filterSupportedPrimitiveDescriptors();
// We also need to filter descs in Convolution node
filterSupportedDescriptors();
}
void Convolution::filterSupportedDescriptors() {
if (inputMemoryFormatsFilter.empty() && outputMemoryFormatsFilter.empty())
return;
if (inputMemoryFormatsFilter.size() > 1 || outputMemoryFormatsFilter.size() > 1)
IE_THROW() << "Incorrect number of input or output memory formats for Convolution node";
auto isNotSuitableDesc = [&](const dnnl::primitive_desc& desc) {
if (!inputMemoryFormatsFilter.empty()) {
auto src_tdesc = DnnlExtensionUtils::makeDescriptor(desc.src_desc());
if (src_tdesc->isSame(inputMemoryFormatsFilter[0])) {
DEBUG_LOG(getName(), " input memory format filter: ", inputMemoryFormatsFilter[0],
" not matched. Erase desc from the list of dnnl primitive descriptors: ", desc);
return true;
}
}
if (!outputMemoryFormatsFilter.empty()) {
auto dst_tdesc = DnnlExtensionUtils::makeDescriptor(desc.dst_desc());
if (dst_tdesc->isSame(outputMemoryFormatsFilter[0])) {
DEBUG_LOG(getName(), " Output memory format filter: ", outputMemoryFormatsFilter[0],
" not matched. Erase desc from the list of dnnl primitive descriptors: ", desc);
return true;
}
}
return false;
};
descs.erase(std::remove_if(descs.begin(), descs.end(), isNotSuitableDesc), descs.end());
}
std::shared_ptr<MemoryDesc> Convolution::getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const {
if (idx == 1) {
// report original plain layout for weight since it needs to be reordered dynamically at runtime

View File

@@ -28,7 +28,6 @@ public:
void initDescriptor(const NodeConfig& config) override;
void selectOptimalPrimitiveDescriptor() override;
void initSupportedPrimitiveDescriptors() override;
void filterSupportedPrimitiveDescriptors() override;
bool created() const override;
bool canBeInPlace() const override {
return false;

View File

@@ -543,48 +543,6 @@ void Deconvolution::setPostOps(dnnl::primitive_attr& attr, const VectorDims& dim
attr.set_post_ops(ops);
}
void Deconvolution::filterSupportedPrimitiveDescriptors() {
Node::filterSupportedPrimitiveDescriptors();
filterSupportedDescriptors();
}
void Deconvolution::filterSupportedDescriptors() {
if (inputMemoryFormatsFilter.empty() && outputMemoryFormatsFilter.empty())
return;
if (inputMemoryFormatsFilter.size() > 1 || outputMemoryFormatsFilter.size() > 1)
IE_THROW() << "Incorrect number of input or output memory formats for Deconvolution node";
auto isNotSuitableDesc = [&](const dnnl::primitive_desc& desc) {
if (!inputMemoryFormatsFilter.empty()) {
auto src_tdesc = isInt8 ? DnnlExtensionUtils::makeDescriptor(desc.src_desc()) :
DnnlExtensionUtils::makeDescriptor(desc.diff_src_desc());
if (!src_tdesc->isSame(inputMemoryFormatsFilter[0])) {
DEBUG_LOG(getName(), " input memory format filter: ", inputMemoryFormatsFilter[0],
"not matched. Erase desc from the list of dnnl primitive descriptors: ", desc);
return true;
}
}
if (!outputMemoryFormatsFilter.empty()) {
auto dst_tdesc = isInt8 ? DnnlExtensionUtils::makeDescriptor(desc.dst_desc()) :
DnnlExtensionUtils::makeDescriptor(desc.diff_dst_desc());
if (!dst_tdesc->isSame(outputMemoryFormatsFilter[0])) {
DEBUG_LOG(getName(), " Output memory format filter: ", outputMemoryFormatsFilter[0],
" not matched. Erase desc from the list of dnnl primitive descriptors: ", desc);
return true;
}
}
return false;
};
descs.erase(std::remove_if(descs.begin(), descs.end(), isNotSuitableDesc), descs.end());
}
bool Deconvolution::created() const {
return getType() == Type::Deconvolution;
}

View File

@@ -23,8 +23,6 @@ public:
void createDescriptor(const std::vector<MemoryDescPtr>& inputDesc,
const std::vector<MemoryDescPtr>& outputDesc) override;
void createPrimitive() override;
void filterSupportedPrimitiveDescriptors() override;
void filterSupportedDescriptors();
bool created() const override;
bool canBeInPlace() const override {
return false;