[CPU] Fix issue about overwriting register names (#15815)

This commit is contained in:
Chen Xu 2023-03-03 18:50:33 +08:00 committed by GitHub
parent 4e8590bf9b
commit 35dacb370a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 52 deletions

View File

@ -27,39 +27,41 @@ template void RegPrinter::print<unsigned char, Reg16>(jit_generator &h, Reg16 re
template void RegPrinter::print<char, Reg8>(jit_generator &h, Reg8 reg, const char *name); template void RegPrinter::print<char, Reg8>(jit_generator &h, Reg8 reg, const char *name);
template void RegPrinter::print<unsigned char, Reg8>(jit_generator &h, Reg8 reg, const char *name); template void RegPrinter::print<unsigned char, Reg8>(jit_generator &h, Reg8 reg, const char *name);
void RegPrinter::print_reg_fp32(const char *name, int val) {
std::stringstream ss;
ss << name << ": " << *reinterpret_cast<float *>(&val) << std::endl;
std::cout << ss.str();
}
template <typename T> template <typename T>
void RegPrinter::print_reg_integer(const char *name, T val) { void RegPrinter::print_reg_prc(const char *name, const char *ori_name, T *ptr) {
std::stringstream ss; std::stringstream ss;
if (std::is_signed<T>::value) { if (name) ss << name << " | ";
ss << name << ": " << static_cast<int64_t>(val) << std::endl; ss << ori_name << ": ";
if (std::is_floating_point<T>::value) {
ss << *ptr;
} else { } else {
ss << name << ": " << static_cast<uint64_t>(val) << std::endl; if (std::is_signed<T>::value) {
ss << static_cast<int64_t>(*ptr);
} else {
ss << static_cast<uint64_t>(*ptr);
} }
}
ss << std::endl;
std::cout << ss.str(); std::cout << ss.str();
} }
template <typename PRC_T, size_t vlen> template <typename PRC_T, size_t vlen>
void RegPrinter::print_vmm_prc(const char *name, PRC_T *ptr) { void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr) {
std::stringstream ss; std::stringstream ss;
ss << name << ": {" << ptr[0]; if (name) ss << name << " | ";
ss << ori_name << ": {" << ptr[0];
for (size_t i = 1; i < vlen / sizeof(float); i++) { for (size_t i = 1; i < vlen / sizeof(float); i++) {
ss << ", " << ptr[i]; ss << ", " << ptr[i];
} }
ss << "}" << std::endl; ss << "}" << std::endl;
std::cout << ss.str(); std::cout << ss.str();
} }
template void RegPrinter::print_vmm_prc<float, 16>(const char *name, float *ptr); template void RegPrinter::print_vmm_prc<float, 16>(const char *name, const char *ori_name, float *ptr);
template void RegPrinter::print_vmm_prc<float, 32>(const char *name, float *ptr); template void RegPrinter::print_vmm_prc<float, 32>(const char *name, const char *ori_name, float *ptr);
template void RegPrinter::print_vmm_prc<float, 64>(const char *name, float *ptr); template void RegPrinter::print_vmm_prc<float, 64>(const char *name, const char *ori_name, float *ptr);
template void RegPrinter::print_vmm_prc<int, 16>(const char *name, int *ptr); template void RegPrinter::print_vmm_prc<int, 16>(const char *name, const char *ori_name, int *ptr);
template void RegPrinter::print_vmm_prc<int, 32>(const char *name, int *ptr); template void RegPrinter::print_vmm_prc<int, 32>(const char *name, const char *ori_name, int *ptr);
template void RegPrinter::print_vmm_prc<int, 64>(const char *name, int *ptr); template void RegPrinter::print_vmm_prc<int, 64>(const char *name, const char *ori_name, int *ptr);
template <typename Vmm> template <typename Vmm>
struct vmm_traits{}; struct vmm_traits{};
@ -124,42 +126,32 @@ void RegPrinter::postamble(jit_generator &h) {
restore_reg(h); restore_reg(h);
} }
template <typename REG_T> // ABI requires 16-byte stack alignment before a call
const char * RegPrinter::get_name(REG_T reg, const char *name) { void RegPrinter::align_rsp(jit_generator &h) {
const char *reg_name = reg.toString(); constexpr int alignment = 16;
h.mov(h.r15, h.rsp);
if (name == nullptr) { h.and_(h.rsp, ~(alignment - 1));
return reg_name;
} else {
constexpr size_t len = 64;
constexpr size_t aux_len = 3;
static char full_name[len];
size_t total_len = std::strlen(name) + std::strlen(reg_name) + aux_len + 1;
if (total_len > len) {
return reg_name;
} else {
snprintf(full_name, len, "%s | %s", name, reg_name);
return full_name;
}
} }
void RegPrinter::restore_rsp(jit_generator &h) {
h.mov(h.rsp, h.r15);
} }
template <typename PRC_T, typename REG_T> template <typename PRC_T, typename REG_T>
void RegPrinter::print_vmm(jit_generator &h, REG_T vmm, const char *name) { void RegPrinter::print_vmm(jit_generator &h, REG_T vmm, const char *name) {
preamble(h); preamble(h);
name = get_name(vmm, name);
h.push(h.rax); h.push(h.rax);
h.push(abi_param1); h.push(abi_param1);
h.push(abi_param2); h.push(abi_param2);
h.push(abi_param3);
{ {
const int vlen = vmm.isZMM() ? 64 : (vmm.isYMM() ? 32 : 16); const int vlen = vmm.isZMM() ? 64 : (vmm.isYMM() ? 32 : 16);
h.sub(h.rsp, vlen); h.sub(h.rsp, vlen);
h.uni_vmovups(h.ptr[h.rsp], vmm); h.uni_vmovups(h.ptr[h.rsp], vmm);
h.mov(abi_param2, h.rsp); h.mov(abi_param3, h.rsp);
h.mov(abi_param2, reinterpret_cast<size_t>(vmm.toString()));
h.mov(abi_param1, reinterpret_cast<size_t>(name)); h.mov(abi_param1, reinterpret_cast<size_t>(name));
if (vmm.isZMM()) { if (vmm.isZMM()) {
h.mov(h.rax, reinterpret_cast<size_t>(&print_vmm_prc<PRC_T, 64>)); h.mov(h.rax, reinterpret_cast<size_t>(&print_vmm_prc<PRC_T, 64>));
@ -168,11 +160,14 @@ void RegPrinter::print_vmm(jit_generator &h, REG_T vmm, const char *name) {
} else { } else {
h.mov(h.rax, reinterpret_cast<size_t>(&print_vmm_prc<PRC_T, 16>)); h.mov(h.rax, reinterpret_cast<size_t>(&print_vmm_prc<PRC_T, 16>));
} }
align_rsp(h);
h.call(h.rax); h.call(h.rax);
restore_rsp(h);
h.add(h.rsp, vlen); h.add(h.rsp, vlen);
} }
h.pop(abi_param3);
h.pop(abi_param2); h.pop(abi_param2);
h.pop(abi_param1); h.pop(abi_param1);
h.pop(h.rax); h.pop(h.rax);
@ -184,21 +179,27 @@ template <typename PRC_T, typename REG_T>
void RegPrinter::print_reg(jit_generator &h, REG_T reg, const char *name) { void RegPrinter::print_reg(jit_generator &h, REG_T reg, const char *name) {
preamble(h); preamble(h);
name = get_name(reg, name);
h.push(h.rax); h.push(h.rax);
h.push(abi_param1); h.push(abi_param1);
h.push(abi_param2); h.push(abi_param2);
h.push(abi_param3);
{ {
h.mov(abi_param2, reg); const int rlen = reg.getBit() / 8;
h.sub(h.rsp, rlen);
h.mov(h.ptr[h.rsp], reg);
h.mov(abi_param3, h.rsp);
h.mov(abi_param2, reinterpret_cast<size_t>(reg.toString()));
h.mov(abi_param1, reinterpret_cast<size_t>(name)); h.mov(abi_param1, reinterpret_cast<size_t>(name));
if (std::is_floating_point<PRC_T>::value) h.mov(h.rax, reinterpret_cast<size_t>(&print_reg_prc<PRC_T>));
h.mov(h.rax, reinterpret_cast<size_t>(&print_reg_fp32)); align_rsp(h);
else
h.mov(h.rax, reinterpret_cast<size_t>(&print_reg_integer<PRC_T>));
h.call(h.rax); h.call(h.rax);
restore_rsp(h);
h.add(h.rsp, rlen);
} }
h.pop(abi_param3);
h.pop(abi_param2); h.pop(abi_param2);
h.pop(abi_param1); h.pop(abi_param1);
h.pop(h.rax); h.pop(h.rax);

View File

@ -12,7 +12,10 @@ namespace intel_cpu {
// Usage // Usage
// 1. Include this header file where JIT kernels of CPU plugin are implemented for Register printing // 1. Include this header file where JIT kernels of CPU plugin are implemented for Register printing
// 2. Invoke RegPrinter::print method. Here are some examples. Note that user friendly register name // 2. Invoke RegPrinter::print method. Here are some examples. Note that user friendly register name
// will be printed, if it has been set. While original Xbyak register name will always be printed. // will be printed, if it has been set. Current implementation doesn't buffer the name. So if you
// choose to set a name for the register, do not pass the name through a local variable; pass a
// string literal directly to the interface, as in the examples. The original Xbyak register name
// will always be printed.
// Example 1: // Example 1:
// Invocation: RegPrinter::print<float>(*this, vmm_val, "vmm_val"); // Invocation: RegPrinter::print<float>(*this, vmm_val, "vmm_val");
// Console: vmm_val | ymm0: {30, 20, 25, 29, 24, 31, 27, 23} // Console: vmm_val | ymm0: {30, 20, 25, 29, 24, 31, 27, 23}
@ -69,10 +72,9 @@ private:
template <typename PRC_T, typename REG_T> template <typename PRC_T, typename REG_T>
static void print_reg(jit_generator &h, REG_T reg, const char *name); static void print_reg(jit_generator &h, REG_T reg, const char *name);
template <typename PRC_T, size_t vlen> template <typename PRC_T, size_t vlen>
static void print_vmm_prc(const char *name, PRC_T *ptr); static void print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr);
template <typename T> template <typename T>
static void print_reg_integer(const char *name, T val); static void print_reg_prc(const char *name, const char *ori_name, T *val);
static void print_reg_fp32(const char *name, int val);
static void preamble(jit_generator &h); static void preamble(jit_generator &h);
static void postamble(jit_generator &h); static void postamble(jit_generator &h);
template <typename T> template <typename T>
@ -81,8 +83,8 @@ private:
static void restore_vmm(jit_generator &h); static void restore_vmm(jit_generator &h);
static void save_reg(jit_generator &h); static void save_reg(jit_generator &h);
static void restore_reg(jit_generator &h); static void restore_reg(jit_generator &h);
template <typename REG_T> static void align_rsp(jit_generator &h);
static const char * get_name(REG_T reg, const char *name); static void restore_rsp(jit_generator &h);
static constexpr size_t reg_len = 8; static constexpr size_t reg_len = 8;
static constexpr size_t reg_cnt = 16; static constexpr size_t reg_cnt = 16;
}; };