[CPU] Fix issue about overwriting register names (#15815)
This commit is contained in:
parent
4e8590bf9b
commit
35dacb370a
@ -27,39 +27,41 @@ template void RegPrinter::print<unsigned char, Reg16>(jit_generator &h, Reg16 re
|
||||
template void RegPrinter::print<char, Reg8>(jit_generator &h, Reg8 reg, const char *name);
|
||||
template void RegPrinter::print<unsigned char, Reg8>(jit_generator &h, Reg8 reg, const char *name);
|
||||
|
||||
// Prints "<name>: <value>" followed by a newline to std::cout, where <value> is
// the bit pattern of `val` read as a float.
//   name - label written before the value
//   val  - integer whose storage is reinterpreted as a float
// The line is assembled in a local stringstream and written to std::cout with a
// single insertion.
// NOTE(review): the int is read through a float pointer (aliasing cast); copying
// the bytes with std::memcpy would avoid that — confirm before changing.
void RegPrinter::print_reg_fp32(const char *name, int val) {
    const float as_float = *reinterpret_cast<float *>(&val);

    std::stringstream line;
    line << name << ": ";
    line << as_float << std::endl;

    std::cout << line.str();
}
|
||||
|
||||
template <typename T>
|
||||
void RegPrinter::print_reg_integer(const char *name, T val) {
|
||||
void RegPrinter::print_reg_prc(const char *name, const char *ori_name, T *ptr) {
|
||||
std::stringstream ss;
|
||||
if (std::is_signed<T>::value) {
|
||||
ss << name << ": " << static_cast<int64_t>(val) << std::endl;
|
||||
if (name) ss << name << " | ";
|
||||
ss << ori_name << ": ";
|
||||
if (std::is_floating_point<T>::value) {
|
||||
ss << *ptr;
|
||||
} else {
|
||||
ss << name << ": " << static_cast<uint64_t>(val) << std::endl;
|
||||
if (std::is_signed<T>::value) {
|
||||
ss << static_cast<int64_t>(*ptr);
|
||||
} else {
|
||||
ss << static_cast<uint64_t>(*ptr);
|
||||
}
|
||||
}
|
||||
ss << std::endl;
|
||||
std::cout << ss.str();
|
||||
}
|
||||
|
||||
template <typename PRC_T, size_t vlen>
|
||||
void RegPrinter::print_vmm_prc(const char *name, PRC_T *ptr) {
|
||||
void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr) {
|
||||
std::stringstream ss;
|
||||
ss << name << ": {" << ptr[0];
|
||||
if (name) ss << name << " | ";
|
||||
ss << ori_name << ": {" << ptr[0];
|
||||
for (size_t i = 1; i < vlen / sizeof(float); i++) {
|
||||
ss << ", " << ptr[i];
|
||||
}
|
||||
ss << "}" << std::endl;
|
||||
std::cout << ss.str();
|
||||
}
|
||||
template void RegPrinter::print_vmm_prc<float, 16>(const char *name, float *ptr);
|
||||
template void RegPrinter::print_vmm_prc<float, 32>(const char *name, float *ptr);
|
||||
template void RegPrinter::print_vmm_prc<float, 64>(const char *name, float *ptr);
|
||||
template void RegPrinter::print_vmm_prc<int, 16>(const char *name, int *ptr);
|
||||
template void RegPrinter::print_vmm_prc<int, 32>(const char *name, int *ptr);
|
||||
template void RegPrinter::print_vmm_prc<int, 64>(const char *name, int *ptr);
|
||||
template void RegPrinter::print_vmm_prc<float, 16>(const char *name, const char *ori_name, float *ptr);
|
||||
template void RegPrinter::print_vmm_prc<float, 32>(const char *name, const char *ori_name, float *ptr);
|
||||
template void RegPrinter::print_vmm_prc<float, 64>(const char *name, const char *ori_name, float *ptr);
|
||||
template void RegPrinter::print_vmm_prc<int, 16>(const char *name, const char *ori_name, int *ptr);
|
||||
template void RegPrinter::print_vmm_prc<int, 32>(const char *name, const char *ori_name, int *ptr);
|
||||
template void RegPrinter::print_vmm_prc<int, 64>(const char *name, const char *ori_name, int *ptr);
|
||||
|
||||
// Primary vmm_traits template; it declares no members here.
// NOTE(review): presumably specialized per vector register type elsewhere in the file — confirm.
template <typename Vmm>
|
||||
struct vmm_traits{};
|
||||
@ -124,42 +126,32 @@ void RegPrinter::postamble(jit_generator &h) {
|
||||
restore_reg(h);
|
||||
}
|
||||
|
||||
template <typename REG_T>
|
||||
const char * RegPrinter::get_name(REG_T reg, const char *name) {
|
||||
const char *reg_name = reg.toString();
|
||||
// ABI requires 16-bype stack alignment before a call
|
||||
void RegPrinter::align_rsp(jit_generator &h) {
|
||||
constexpr int alignment = 16;
|
||||
h.mov(h.r15, h.rsp);
|
||||
h.and_(h.rsp, ~(alignment - 1));
|
||||
}
|
||||
|
||||
if (name == nullptr) {
|
||||
return reg_name;
|
||||
} else {
|
||||
constexpr size_t len = 64;
|
||||
constexpr size_t aux_len = 3;
|
||||
static char full_name[len];
|
||||
|
||||
size_t total_len = std::strlen(name) + std::strlen(reg_name) + aux_len + 1;
|
||||
if (total_len > len) {
|
||||
return reg_name;
|
||||
} else {
|
||||
snprintf(full_name, len, "%s | %s", name, reg_name);
|
||||
return full_name;
|
||||
}
|
||||
}
|
||||
// Emits code that copies r15 back into rsp. align_rsp() stores the unaligned
// rsp in r15, so this undoes that alignment after the call.
void RegPrinter::restore_rsp(jit_generator &h) {
|
||||
h.mov(h.rsp, h.r15);
|
||||
}
|
||||
|
||||
template <typename PRC_T, typename REG_T>
|
||||
void RegPrinter::print_vmm(jit_generator &h, REG_T vmm, const char *name) {
|
||||
preamble(h);
|
||||
|
||||
name = get_name(vmm, name);
|
||||
|
||||
h.push(h.rax);
|
||||
h.push(abi_param1);
|
||||
h.push(abi_param2);
|
||||
h.push(abi_param3);
|
||||
{
|
||||
const int vlen = vmm.isZMM() ? 64 : (vmm.isYMM() ? 32 : 16);
|
||||
h.sub(h.rsp, vlen);
|
||||
h.uni_vmovups(h.ptr[h.rsp], vmm);
|
||||
|
||||
h.mov(abi_param2, h.rsp);
|
||||
h.mov(abi_param3, h.rsp);
|
||||
h.mov(abi_param2, reinterpret_cast<size_t>(vmm.toString()));
|
||||
h.mov(abi_param1, reinterpret_cast<size_t>(name));
|
||||
if (vmm.isZMM()) {
|
||||
h.mov(h.rax, reinterpret_cast<size_t>(&print_vmm_prc<PRC_T, 64>));
|
||||
@ -168,11 +160,14 @@ void RegPrinter::print_vmm(jit_generator &h, REG_T vmm, const char *name) {
|
||||
} else {
|
||||
h.mov(h.rax, reinterpret_cast<size_t>(&print_vmm_prc<PRC_T, 16>));
|
||||
}
|
||||
align_rsp(h);
|
||||
h.call(h.rax);
|
||||
restore_rsp(h);
|
||||
|
||||
h.add(h.rsp, vlen);
|
||||
}
|
||||
|
||||
h.pop(abi_param3);
|
||||
h.pop(abi_param2);
|
||||
h.pop(abi_param1);
|
||||
h.pop(h.rax);
|
||||
@ -184,21 +179,27 @@ template <typename PRC_T, typename REG_T>
|
||||
void RegPrinter::print_reg(jit_generator &h, REG_T reg, const char *name) {
|
||||
preamble(h);
|
||||
|
||||
name = get_name(reg, name);
|
||||
|
||||
h.push(h.rax);
|
||||
h.push(abi_param1);
|
||||
h.push(abi_param2);
|
||||
h.push(abi_param3);
|
||||
{
|
||||
h.mov(abi_param2, reg);
|
||||
const int rlen = reg.getBit() / 8;
|
||||
h.sub(h.rsp, rlen);
|
||||
h.mov(h.ptr[h.rsp], reg);
|
||||
|
||||
h.mov(abi_param3, h.rsp);
|
||||
h.mov(abi_param2, reinterpret_cast<size_t>(reg.toString()));
|
||||
h.mov(abi_param1, reinterpret_cast<size_t>(name));
|
||||
if (std::is_floating_point<PRC_T>::value)
|
||||
h.mov(h.rax, reinterpret_cast<size_t>(&print_reg_fp32));
|
||||
else
|
||||
h.mov(h.rax, reinterpret_cast<size_t>(&print_reg_integer<PRC_T>));
|
||||
h.mov(h.rax, reinterpret_cast<size_t>(&print_reg_prc<PRC_T>));
|
||||
align_rsp(h);
|
||||
h.call(h.rax);
|
||||
restore_rsp(h);
|
||||
|
||||
h.add(h.rsp, rlen);
|
||||
}
|
||||
|
||||
h.pop(abi_param3);
|
||||
h.pop(abi_param2);
|
||||
h.pop(abi_param1);
|
||||
h.pop(h.rax);
|
||||
|
@ -12,7 +12,10 @@ namespace intel_cpu {
|
||||
// Usage
|
||||
// 1. Include this header file where JIT kernels of the CPU plugin are implemented, to enable register printing
|
||||
// 2. Invoke RegPrinter::print method. Here are some examples. Note that user friendly register name
|
||||
// will be printed, if it has been set. While original Xbyak register name will always be printed.
|
||||
// will be printed, if it has been set. The current implementation doesn't buffer the name, so if you
|
||||
// choose to set a name for the register, do not use a local variable to pass the name; just pass a
|
||||
// direct string to the interface, as in the examples. The original Xbyak register name will always be
|
||||
// printed.
|
||||
// Example 1:
|
||||
// Invocation: RegPrinter::print<float>(*this, vmm_val, "vmm_val");
|
||||
// Console: vmm_val | ymm0: {30, 20, 25, 29, 24, 31, 27, 23}
|
||||
@ -69,10 +72,9 @@ private:
|
||||
template <typename PRC_T, typename REG_T>
|
||||
static void print_reg(jit_generator &h, REG_T reg, const char *name);
|
||||
template <typename PRC_T, size_t vlen>
|
||||
static void print_vmm_prc(const char *name, PRC_T *ptr);
|
||||
static void print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr);
|
||||
template <typename T>
|
||||
static void print_reg_integer(const char *name, T val);
|
||||
static void print_reg_fp32(const char *name, int val);
|
||||
static void print_reg_prc(const char *name, const char *ori_name, T *val);
|
||||
static void preamble(jit_generator &h);
|
||||
static void postamble(jit_generator &h);
|
||||
template <typename T>
|
||||
@ -81,8 +83,8 @@ private:
|
||||
static void restore_vmm(jit_generator &h);
|
||||
static void save_reg(jit_generator &h);
|
||||
static void restore_reg(jit_generator &h);
|
||||
template <typename REG_T>
|
||||
static const char * get_name(REG_T reg, const char *name);
|
||||
static void align_rsp(jit_generator &h);
|
||||
static void restore_rsp(jit_generator &h);
|
||||
static constexpr size_t reg_len = 8;
|
||||
static constexpr size_t reg_cnt = 16;
|
||||
};
|
||||
|
Loading…
Reference in New Issue
Block a user