From 35dacb370a5d4cb4865e0ab7be7fbee2effb9efb Mon Sep 17 00:00:00 2001 From: Chen Xu Date: Fri, 3 Mar 2023 18:50:33 +0800 Subject: [PATCH] [CPU] Fix issue about overwriting register names (#15815) --- src/plugins/intel_cpu/src/emitters/utils.cpp | 93 ++++++++++---------- src/plugins/intel_cpu/src/emitters/utils.hpp | 14 +-- 2 files changed, 55 insertions(+), 52 deletions(-) diff --git a/src/plugins/intel_cpu/src/emitters/utils.cpp b/src/plugins/intel_cpu/src/emitters/utils.cpp index 19cb6c30a0c..236be4118e3 100644 --- a/src/plugins/intel_cpu/src/emitters/utils.cpp +++ b/src/plugins/intel_cpu/src/emitters/utils.cpp @@ -27,39 +27,41 @@ template void RegPrinter::print(jit_generator &h, Reg16 re template void RegPrinter::print(jit_generator &h, Reg8 reg, const char *name); template void RegPrinter::print(jit_generator &h, Reg8 reg, const char *name); -void RegPrinter::print_reg_fp32(const char *name, int val) { - std::stringstream ss; - ss << name << ": " << *reinterpret_cast(&val) << std::endl; - std::cout << ss.str(); -} - template -void RegPrinter::print_reg_integer(const char *name, T val) { +void RegPrinter::print_reg_prc(const char *name, const char *ori_name, T *ptr) { std::stringstream ss; - if (std::is_signed::value) { - ss << name << ": " << static_cast(val) << std::endl; + if (name) ss << name << " | "; + ss << ori_name << ": "; + if (std::is_floating_point::value) { + ss << *ptr; } else { - ss << name << ": " << static_cast(val) << std::endl; + if (std::is_signed::value) { + ss << static_cast(*ptr); + } else { + ss << static_cast(*ptr); + } } + ss << std::endl; std::cout << ss.str(); } template -void RegPrinter::print_vmm_prc(const char *name, PRC_T *ptr) { +void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr) { std::stringstream ss; - ss << name << ": {" << ptr[0]; + if (name) ss << name << " | "; + ss << ori_name << ": {" << ptr[0]; for (size_t i = 1; i < vlen / sizeof(float); i++) { ss << ", " << ptr[i]; } ss << "}" << std::endl; std::cout << ss.str(); } -template void RegPrinter::print_vmm_prc(const char *name, float *ptr); -template void RegPrinter::print_vmm_prc(const char *name, float *ptr); -template void RegPrinter::print_vmm_prc(const char *name, float *ptr); -template void RegPrinter::print_vmm_prc(const char *name, int *ptr); -template void RegPrinter::print_vmm_prc(const char *name, int *ptr); -template void RegPrinter::print_vmm_prc(const char *name, int *ptr); +template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, float *ptr); +template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, float *ptr); +template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, float *ptr); +template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, int *ptr); +template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, int *ptr); +template void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, int *ptr); template struct vmm_traits{}; @@ -124,42 +126,32 @@ void RegPrinter::postamble(jit_generator &h) { restore_reg(h); } -template -const char * RegPrinter::get_name(REG_T reg, const char *name) { - const char *reg_name = reg.toString(); +// ABI requires 16-bype stack alignment before a call +void RegPrinter::align_rsp(jit_generator &h) { + constexpr int alignment = 16; + h.mov(h.r15, h.rsp); + h.and_(h.rsp, ~(alignment - 1)); +} - if (name == nullptr) { - return reg_name; - } else { - constexpr size_t len = 64; - constexpr size_t aux_len = 3; - static char full_name[len]; - - size_t total_len = std::strlen(name) + std::strlen(reg_name) + aux_len + 1; - if (total_len > len) { - return reg_name; - } else { - snprintf(full_name, len, "%s | %s", name, reg_name); - return full_name; - } - } +void RegPrinter::restore_rsp(jit_generator &h) { + h.mov(h.rsp, h.r15); } template void RegPrinter::print_vmm(jit_generator &h, REG_T vmm, const char *name) { preamble(h); - name = get_name(vmm, name); - h.push(h.rax); h.push(abi_param1); h.push(abi_param2); + h.push(abi_param3); { const int vlen = vmm.isZMM() ? 64 : (vmm.isYMM() ? 32 : 16); h.sub(h.rsp, vlen); h.uni_vmovups(h.ptr[h.rsp], vmm); - h.mov(abi_param2, h.rsp); + h.mov(abi_param3, h.rsp); + h.mov(abi_param2, reinterpret_cast(vmm.toString())); h.mov(abi_param1, reinterpret_cast(name)); if (vmm.isZMM()) { h.mov(h.rax, reinterpret_cast(&print_vmm_prc)); @@ -168,11 +160,14 @@ void RegPrinter::print_vmm(jit_generator &h, REG_T vmm, const char *name) { } else { h.mov(h.rax, reinterpret_cast(&print_vmm_prc)); } + align_rsp(h); h.call(h.rax); + restore_rsp(h); h.add(h.rsp, vlen); } + h.pop(abi_param3); h.pop(abi_param2); h.pop(abi_param1); h.pop(h.rax); @@ -184,21 +179,27 @@ template void RegPrinter::print_reg(jit_generator &h, REG_T reg, const char *name) { preamble(h); - name = get_name(reg, name); - h.push(h.rax); h.push(abi_param1); h.push(abi_param2); + h.push(abi_param3); { - h.mov(abi_param2, reg); + const int rlen = reg.getBit() / 8; + h.sub(h.rsp, rlen); + h.mov(h.ptr[h.rsp], reg); + + h.mov(abi_param3, h.rsp); + h.mov(abi_param2, reinterpret_cast(reg.toString())); h.mov(abi_param1, reinterpret_cast(name)); - if (std::is_floating_point::value) - h.mov(h.rax, reinterpret_cast(&print_reg_fp32)); - else - h.mov(h.rax, reinterpret_cast(&print_reg_integer)); + h.mov(h.rax, reinterpret_cast(&print_reg_prc)); + align_rsp(h); h.call(h.rax); + restore_rsp(h); + + h.add(h.rsp, rlen); } + h.pop(abi_param3); h.pop(abi_param2); h.pop(abi_param1); h.pop(h.rax); diff --git a/src/plugins/intel_cpu/src/emitters/utils.hpp b/src/plugins/intel_cpu/src/emitters/utils.hpp index d14b35819f5..af79d08c867 100644 --- a/src/plugins/intel_cpu/src/emitters/utils.hpp +++ b/src/plugins/intel_cpu/src/emitters/utils.hpp @@ -12,7 +12,10 @@ namespace intel_cpu { // Usage // 1. Include this headfile where JIT kennels of CPU plugin are implemented for Register printing // 2. Invoke RegPrinter::print method. Here are some examples. Note that user friendly register name -// will be printed, if it has been set. While original Xbyak register name will always be printed. +// will be printed, if it has been set. Current implementation doesn't buffer the name. So if you +// choose to set a name for the register, do not use local variable to pass the name, just pass a +// direct string to the interface like examples. While Original Xbyak register name will always be +// printed. // Example 1: // Invocation: RegPrinter::print(*this, vmm_val, "vmm_val"); // Console: vmm_val | ymm0: {30, 20, 25, 29, 24, 31, 27, 23} @@ -69,10 +72,9 @@ private: template static void print_reg(jit_generator &h, REG_T reg, const char *name); template - static void print_vmm_prc(const char *name, PRC_T *ptr); + static void print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr); template - static void print_reg_integer(const char *name, T val); - static void print_reg_fp32(const char *name, int val); + static void print_reg_prc(const char *name, const char *ori_name, T *val); static void preamble(jit_generator &h); static void postamble(jit_generator &h); template @@ -81,8 +83,8 @@ private: static void restore_vmm(jit_generator &h); static void save_reg(jit_generator &h); static void restore_reg(jit_generator &h); - template - static const char * get_name(REG_T reg, const char *name); + static void align_rsp(jit_generator &h); + static void restore_rsp(jit_generator &h); static constexpr size_t reg_len = 8; static constexpr size_t reg_cnt = 16; };