[CPU] Fix issue about overwriting register names (#15815)

This commit is contained in:
Chen Xu 2023-03-03 18:50:33 +08:00 committed by GitHub
parent 4e8590bf9b
commit 35dacb370a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 55 additions and 52 deletions

View File

@ -27,39 +27,41 @@ template void RegPrinter::print<unsigned char, Reg16>(jit_generator &h, Reg16 re
template void RegPrinter::print<char, Reg8>(jit_generator &h, Reg8 reg, const char *name);
template void RegPrinter::print<unsigned char, Reg8>(jit_generator &h, Reg8 reg, const char *name);
void RegPrinter::print_reg_fp32(const char *name, int val) {
std::stringstream ss;
ss << name << ": " << *reinterpret_cast<float *>(&val) << std::endl;
std::cout << ss.str();
}
template <typename T>
void RegPrinter::print_reg_integer(const char *name, T val) {
void RegPrinter::print_reg_prc(const char *name, const char *ori_name, T *ptr) {
std::stringstream ss;
if (std::is_signed<T>::value) {
ss << name << ": " << static_cast<int64_t>(val) << std::endl;
if (name) ss << name << " | ";
ss << ori_name << ": ";
if (std::is_floating_point<T>::value) {
ss << *ptr;
} else {
ss << name << ": " << static_cast<uint64_t>(val) << std::endl;
if (std::is_signed<T>::value) {
ss << static_cast<int64_t>(*ptr);
} else {
ss << static_cast<uint64_t>(*ptr);
}
}
ss << std::endl;
std::cout << ss.str();
}
template <typename PRC_T, size_t vlen>
void RegPrinter::print_vmm_prc(const char *name, PRC_T *ptr) {
void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr) {
std::stringstream ss;
ss << name << ": {" << ptr[0];
if (name) ss << name << " | ";
ss << ori_name << ": {" << ptr[0];
for (size_t i = 1; i < vlen / sizeof(float); i++) {
ss << ", " << ptr[i];
}
ss << "}" << std::endl;
std::cout << ss.str();
}
template void RegPrinter::print_vmm_prc<float, 16>(const char *name, float *ptr);
template void RegPrinter::print_vmm_prc<float, 32>(const char *name, float *ptr);
template void RegPrinter::print_vmm_prc<float, 64>(const char *name, float *ptr);
template void RegPrinter::print_vmm_prc<int, 16>(const char *name, int *ptr);
template void RegPrinter::print_vmm_prc<int, 32>(const char *name, int *ptr);
template void RegPrinter::print_vmm_prc<int, 64>(const char *name, int *ptr);
template void RegPrinter::print_vmm_prc<float, 16>(const char *name, const char *ori_name, float *ptr);
template void RegPrinter::print_vmm_prc<float, 32>(const char *name, const char *ori_name, float *ptr);
template void RegPrinter::print_vmm_prc<float, 64>(const char *name, const char *ori_name, float *ptr);
template void RegPrinter::print_vmm_prc<int, 16>(const char *name, const char *ori_name, int *ptr);
template void RegPrinter::print_vmm_prc<int, 32>(const char *name, const char *ori_name, int *ptr);
template void RegPrinter::print_vmm_prc<int, 64>(const char *name, const char *ori_name, int *ptr);
template <typename Vmm>
struct vmm_traits{};
@ -124,42 +126,32 @@ void RegPrinter::postamble(jit_generator &h) {
restore_reg(h);
}
template <typename REG_T>
const char * RegPrinter::get_name(REG_T reg, const char *name) {
const char *reg_name = reg.toString();
// ABI requires 16-bype stack alignment before a call
void RegPrinter::align_rsp(jit_generator &h) {
constexpr int alignment = 16;
h.mov(h.r15, h.rsp);
h.and_(h.rsp, ~(alignment - 1));
}
if (name == nullptr) {
return reg_name;
} else {
constexpr size_t len = 64;
constexpr size_t aux_len = 3;
static char full_name[len];
size_t total_len = std::strlen(name) + std::strlen(reg_name) + aux_len + 1;
if (total_len > len) {
return reg_name;
} else {
snprintf(full_name, len, "%s | %s", name, reg_name);
return full_name;
}
}
void RegPrinter::restore_rsp(jit_generator &h) {
h.mov(h.rsp, h.r15);
}
template <typename PRC_T, typename REG_T>
void RegPrinter::print_vmm(jit_generator &h, REG_T vmm, const char *name) {
preamble(h);
name = get_name(vmm, name);
h.push(h.rax);
h.push(abi_param1);
h.push(abi_param2);
h.push(abi_param3);
{
const int vlen = vmm.isZMM() ? 64 : (vmm.isYMM() ? 32 : 16);
h.sub(h.rsp, vlen);
h.uni_vmovups(h.ptr[h.rsp], vmm);
h.mov(abi_param2, h.rsp);
h.mov(abi_param3, h.rsp);
h.mov(abi_param2, reinterpret_cast<size_t>(vmm.toString()));
h.mov(abi_param1, reinterpret_cast<size_t>(name));
if (vmm.isZMM()) {
h.mov(h.rax, reinterpret_cast<size_t>(&print_vmm_prc<PRC_T, 64>));
@ -168,11 +160,14 @@ void RegPrinter::print_vmm(jit_generator &h, REG_T vmm, const char *name) {
} else {
h.mov(h.rax, reinterpret_cast<size_t>(&print_vmm_prc<PRC_T, 16>));
}
align_rsp(h);
h.call(h.rax);
restore_rsp(h);
h.add(h.rsp, vlen);
}
h.pop(abi_param3);
h.pop(abi_param2);
h.pop(abi_param1);
h.pop(h.rax);
@ -184,21 +179,27 @@ template <typename PRC_T, typename REG_T>
void RegPrinter::print_reg(jit_generator &h, REG_T reg, const char *name) {
preamble(h);
name = get_name(reg, name);
h.push(h.rax);
h.push(abi_param1);
h.push(abi_param2);
h.push(abi_param3);
{
h.mov(abi_param2, reg);
const int rlen = reg.getBit() / 8;
h.sub(h.rsp, rlen);
h.mov(h.ptr[h.rsp], reg);
h.mov(abi_param3, h.rsp);
h.mov(abi_param2, reinterpret_cast<size_t>(reg.toString()));
h.mov(abi_param1, reinterpret_cast<size_t>(name));
if (std::is_floating_point<PRC_T>::value)
h.mov(h.rax, reinterpret_cast<size_t>(&print_reg_fp32));
else
h.mov(h.rax, reinterpret_cast<size_t>(&print_reg_integer<PRC_T>));
h.mov(h.rax, reinterpret_cast<size_t>(&print_reg_prc<PRC_T>));
align_rsp(h);
h.call(h.rax);
restore_rsp(h);
h.add(h.rsp, rlen);
}
h.pop(abi_param3);
h.pop(abi_param2);
h.pop(abi_param1);
h.pop(h.rax);

View File

@ -12,7 +12,10 @@ namespace intel_cpu {
// Usage
// 1. Include this headfile where JIT kennels of CPU plugin are implemented for Register printing
// 2. Invoke RegPrinter::print method. Here are some examples. Note that user friendly register name
// will be printed, if it has been set. While original Xbyak register name will always be printed.
// will be printed, if it has been set. Current implementation doesn't buffer the name. So if you
// choose to set a name for the register, do not use local variable to pass the name, just pass a
// direct string to the interface like examples. While Original Xbyak register name will always be
// printed.
// Example 1:
// Invocation: RegPrinter::print<float>(*this, vmm_val, "vmm_val");
// Console: vmm_val | ymm0: {30, 20, 25, 29, 24, 31, 27, 23}
@ -69,10 +72,9 @@ private:
template <typename PRC_T, typename REG_T>
static void print_reg(jit_generator &h, REG_T reg, const char *name);
template <typename PRC_T, size_t vlen>
static void print_vmm_prc(const char *name, PRC_T *ptr);
static void print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr);
template <typename T>
static void print_reg_integer(const char *name, T val);
static void print_reg_fp32(const char *name, int val);
static void print_reg_prc(const char *name, const char *ori_name, T *val);
static void preamble(jit_generator &h);
static void postamble(jit_generator &h);
template <typename T>
@ -81,8 +83,8 @@ private:
static void restore_vmm(jit_generator &h);
static void save_reg(jit_generator &h);
static void restore_reg(jit_generator &h);
template <typename REG_T>
static const char * get_name(REG_T reg, const char *name);
static void align_rsp(jit_generator &h);
static void restore_rsp(jit_generator &h);
static constexpr size_t reg_len = 8;
static constexpr size_t reg_cnt = 16;
};