updated to init onednn attr (#19055)
This commit is contained in:
parent
32a6a31de2
commit
9deef1480a
@ -386,8 +386,16 @@ public:
|
||||
std::vector<fused_primitive_desc>& get_fused_primitives() { return fused_prims; }
|
||||
|
||||
#ifdef ENABLE_ONEDNN_FOR_GPU
|
||||
const std::shared_ptr<dnnl::primitive_attr>& get_onednn_primitive_attributes() const { return onednn_attrs; }
|
||||
std::shared_ptr<dnnl::primitive_attr>& get_onednn_primitive_attributes() { return onednn_attrs; }
|
||||
const std::shared_ptr<dnnl::primitive_attr>& get_onednn_primitive_attributes() const {
|
||||
if (onednn_attrs == nullptr)
|
||||
const_cast<program_node*>(this)->init_onednn_primitive_attributes();
|
||||
return onednn_attrs;
|
||||
}
|
||||
std::shared_ptr<dnnl::primitive_attr>& get_onednn_primitive_attributes() {
|
||||
if (onednn_attrs == nullptr)
|
||||
init_onednn_primitive_attributes();
|
||||
return onednn_attrs;
|
||||
}
|
||||
|
||||
const std::vector<fused_primitive_desc_onednn>& get_fused_primitives_onednn() const { return fused_prims_onednn; }
|
||||
std::vector<fused_primitive_desc_onednn>& get_fused_primitives_onednn() { return fused_prims_onednn; }
|
||||
|
Loading…
Reference in New Issue
Block a user