[CPU] Add early throw for filtering primitive descriptors (#17065)
Otherwise the plugin will throw later when trying to init optimal primitive descriptor (when there is actually no descriptors at all)
This commit is contained in:
@@ -335,8 +335,9 @@ void Node::selectPreferPrimitiveDescriptor(const std::vector<impl_desc_type>& pr
|
||||
}
|
||||
}
|
||||
|
||||
if (getSupportedPrimitiveDescriptors().empty())
|
||||
IE_THROW() << "Supported primitive descriptors list is empty for node: " << getName();
|
||||
IE_ASSERT(!getSupportedPrimitiveDescriptors().empty()) <<
|
||||
"Supported primitive descriptors list is empty for node: " << getName() << " type: " << NameFromType(getType());
|
||||
|
||||
// fallback. If there are no primitives from priority list just select a first
|
||||
selectPrimitiveDescriptorByIndex(0);
|
||||
}
|
||||
@@ -711,6 +712,9 @@ void Node::filterSupportedPrimitiveDescriptors() {
|
||||
supportedPrimitiveDescriptors.erase(
|
||||
std::remove_if(supportedPrimitiveDescriptors.begin(), supportedPrimitiveDescriptors.end(), isNotSuitableDesc),
|
||||
supportedPrimitiveDescriptors.end());
|
||||
|
||||
IE_ASSERT(!supportedPrimitiveDescriptors.empty()) << getName() << " type: " << NameFromType(getType()) <<
|
||||
" No supported primitive descriptors matched the provided input / output memory format filters.";
|
||||
}
|
||||
|
||||
void Node::initDescriptor(const NodeConfig& config) {
|
||||
|
||||
@@ -369,7 +369,7 @@ public:
|
||||
* @brief Filters supportedPrimitiveDescriptors according to the input layouts specified in inputMemoryFormatsFilter
|
||||
* and output layouts specified in outputMemoryFormatsFilter
|
||||
*/
|
||||
virtual void filterSupportedPrimitiveDescriptors();
|
||||
void filterSupportedPrimitiveDescriptors();
|
||||
|
||||
virtual void createPrimitive();
|
||||
|
||||
|
||||
@@ -1051,43 +1051,6 @@ void Convolution::initDescriptor(const NodeConfig& config) {
|
||||
selectedPD->setConfig(updatedConfig);
|
||||
}
|
||||
|
||||
void Convolution::filterSupportedPrimitiveDescriptors() {
|
||||
Node::filterSupportedPrimitiveDescriptors();
|
||||
// We also need to filter descs in Convolution node
|
||||
filterSupportedDescriptors();
|
||||
}
|
||||
|
||||
void Convolution::filterSupportedDescriptors() {
|
||||
if (inputMemoryFormatsFilter.empty() && outputMemoryFormatsFilter.empty())
|
||||
return;
|
||||
|
||||
if (inputMemoryFormatsFilter.size() > 1 || outputMemoryFormatsFilter.size() > 1)
|
||||
IE_THROW() << "Incorrect number of input or output memory formats for Convolution node";
|
||||
|
||||
auto isNotSuitableDesc = [&](const dnnl::primitive_desc& desc) {
|
||||
if (!inputMemoryFormatsFilter.empty()) {
|
||||
auto src_tdesc = DnnlExtensionUtils::makeDescriptor(desc.src_desc());
|
||||
if (src_tdesc->isSame(inputMemoryFormatsFilter[0])) {
|
||||
DEBUG_LOG(getName(), " input memory format filter: ", inputMemoryFormatsFilter[0],
|
||||
" not matched. Erase desc from the list of dnnl primitive descriptors: ", desc);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
if (!outputMemoryFormatsFilter.empty()) {
|
||||
auto dst_tdesc = DnnlExtensionUtils::makeDescriptor(desc.dst_desc());
|
||||
if (dst_tdesc->isSame(outputMemoryFormatsFilter[0])) {
|
||||
DEBUG_LOG(getName(), " Output memory format filter: ", outputMemoryFormatsFilter[0],
|
||||
" not matched. Erase desc from the list of dnnl primitive descriptors: ", desc);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
descs.erase(std::remove_if(descs.begin(), descs.end(), isNotSuitableDesc), descs.end());
|
||||
}
|
||||
|
||||
std::shared_ptr<MemoryDesc> Convolution::getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const {
|
||||
if (idx == 1) {
|
||||
// report original plain layout for weight since it needs to be reordered dynamically at runtime
|
||||
|
||||
@@ -28,7 +28,6 @@ public:
|
||||
void initDescriptor(const NodeConfig& config) override;
|
||||
void selectOptimalPrimitiveDescriptor() override;
|
||||
void initSupportedPrimitiveDescriptors() override;
|
||||
void filterSupportedPrimitiveDescriptors() override;
|
||||
bool created() const override;
|
||||
bool canBeInPlace() const override {
|
||||
return false;
|
||||
|
||||
@@ -543,48 +543,6 @@ void Deconvolution::setPostOps(dnnl::primitive_attr& attr, const VectorDims& dim
|
||||
attr.set_post_ops(ops);
|
||||
}
|
||||
|
||||
void Deconvolution::filterSupportedPrimitiveDescriptors() {
|
||||
Node::filterSupportedPrimitiveDescriptors();
|
||||
filterSupportedDescriptors();
|
||||
}
|
||||
|
||||
void Deconvolution::filterSupportedDescriptors() {
|
||||
if (inputMemoryFormatsFilter.empty() && outputMemoryFormatsFilter.empty())
|
||||
return;
|
||||
|
||||
if (inputMemoryFormatsFilter.size() > 1 || outputMemoryFormatsFilter.size() > 1)
|
||||
IE_THROW() << "Incorrect number of input or output memory formats for Deconvolution node";
|
||||
|
||||
auto isNotSuitableDesc = [&](const dnnl::primitive_desc& desc) {
|
||||
if (!inputMemoryFormatsFilter.empty()) {
|
||||
auto src_tdesc = isInt8 ? DnnlExtensionUtils::makeDescriptor(desc.src_desc()) :
|
||||
DnnlExtensionUtils::makeDescriptor(desc.diff_src_desc());
|
||||
|
||||
if (!src_tdesc->isSame(inputMemoryFormatsFilter[0])) {
|
||||
DEBUG_LOG(getName(), " input memory format filter: ", inputMemoryFormatsFilter[0],
|
||||
"not matched. Erase desc from the list of dnnl primitive descriptors: ", desc);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
if (!outputMemoryFormatsFilter.empty()) {
|
||||
auto dst_tdesc = isInt8 ? DnnlExtensionUtils::makeDescriptor(desc.dst_desc()) :
|
||||
DnnlExtensionUtils::makeDescriptor(desc.diff_dst_desc());
|
||||
|
||||
if (!dst_tdesc->isSame(outputMemoryFormatsFilter[0])) {
|
||||
DEBUG_LOG(getName(), " Output memory format filter: ", outputMemoryFormatsFilter[0],
|
||||
" not matched. Erase desc from the list of dnnl primitive descriptors: ", desc);
|
||||
|
||||
return true;
|
||||
}
|
||||
}
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
descs.erase(std::remove_if(descs.begin(), descs.end(), isNotSuitableDesc), descs.end());
|
||||
}
|
||||
|
||||
bool Deconvolution::created() const {
|
||||
return getType() == Type::Deconvolution;
|
||||
}
|
||||
|
||||
@@ -23,8 +23,6 @@ public:
|
||||
void createDescriptor(const std::vector<MemoryDescPtr>& inputDesc,
|
||||
const std::vector<MemoryDescPtr>& outputDesc) override;
|
||||
void createPrimitive() override;
|
||||
void filterSupportedPrimitiveDescriptors() override;
|
||||
void filterSupportedDescriptors();
|
||||
bool created() const override;
|
||||
bool canBeInPlace() const override {
|
||||
return false;
|
||||
|
||||
Reference in New Issue
Block a user