[CPU] Fix issue about overwriting register names (#15815)
This commit is contained in:
parent
4e8590bf9b
commit
35dacb370a
@ -27,39 +27,41 @@ template void RegPrinter::print<unsigned char, Reg16>(jit_generator &h, Reg16 re
|
|||||||
template void RegPrinter::print<char, Reg8>(jit_generator &h, Reg8 reg, const char *name);
|
template void RegPrinter::print<char, Reg8>(jit_generator &h, Reg8 reg, const char *name);
|
||||||
template void RegPrinter::print<unsigned char, Reg8>(jit_generator &h, Reg8 reg, const char *name);
|
template void RegPrinter::print<unsigned char, Reg8>(jit_generator &h, Reg8 reg, const char *name);
|
||||||
|
|
||||||
void RegPrinter::print_reg_fp32(const char *name, int val) {
|
|
||||||
std::stringstream ss;
|
|
||||||
ss << name << ": " << *reinterpret_cast<float *>(&val) << std::endl;
|
|
||||||
std::cout << ss.str();
|
|
||||||
}
|
|
||||||
|
|
||||||
template <typename T>
|
template <typename T>
|
||||||
void RegPrinter::print_reg_integer(const char *name, T val) {
|
void RegPrinter::print_reg_prc(const char *name, const char *ori_name, T *ptr) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
if (std::is_signed<T>::value) {
|
if (name) ss << name << " | ";
|
||||||
ss << name << ": " << static_cast<int64_t>(val) << std::endl;
|
ss << ori_name << ": ";
|
||||||
|
if (std::is_floating_point<T>::value) {
|
||||||
|
ss << *ptr;
|
||||||
} else {
|
} else {
|
||||||
ss << name << ": " << static_cast<uint64_t>(val) << std::endl;
|
if (std::is_signed<T>::value) {
|
||||||
|
ss << static_cast<int64_t>(*ptr);
|
||||||
|
} else {
|
||||||
|
ss << static_cast<uint64_t>(*ptr);
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
ss << std::endl;
|
||||||
std::cout << ss.str();
|
std::cout << ss.str();
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename PRC_T, size_t vlen>
|
template <typename PRC_T, size_t vlen>
|
||||||
void RegPrinter::print_vmm_prc(const char *name, PRC_T *ptr) {
|
void RegPrinter::print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr) {
|
||||||
std::stringstream ss;
|
std::stringstream ss;
|
||||||
ss << name << ": {" << ptr[0];
|
if (name) ss << name << " | ";
|
||||||
|
ss << ori_name << ": {" << ptr[0];
|
||||||
for (size_t i = 1; i < vlen / sizeof(float); i++) {
|
for (size_t i = 1; i < vlen / sizeof(float); i++) {
|
||||||
ss << ", " << ptr[i];
|
ss << ", " << ptr[i];
|
||||||
}
|
}
|
||||||
ss << "}" << std::endl;
|
ss << "}" << std::endl;
|
||||||
std::cout << ss.str();
|
std::cout << ss.str();
|
||||||
}
|
}
|
||||||
template void RegPrinter::print_vmm_prc<float, 16>(const char *name, float *ptr);
|
template void RegPrinter::print_vmm_prc<float, 16>(const char *name, const char *ori_name, float *ptr);
|
||||||
template void RegPrinter::print_vmm_prc<float, 32>(const char *name, float *ptr);
|
template void RegPrinter::print_vmm_prc<float, 32>(const char *name, const char *ori_name, float *ptr);
|
||||||
template void RegPrinter::print_vmm_prc<float, 64>(const char *name, float *ptr);
|
template void RegPrinter::print_vmm_prc<float, 64>(const char *name, const char *ori_name, float *ptr);
|
||||||
template void RegPrinter::print_vmm_prc<int, 16>(const char *name, int *ptr);
|
template void RegPrinter::print_vmm_prc<int, 16>(const char *name, const char *ori_name, int *ptr);
|
||||||
template void RegPrinter::print_vmm_prc<int, 32>(const char *name, int *ptr);
|
template void RegPrinter::print_vmm_prc<int, 32>(const char *name, const char *ori_name, int *ptr);
|
||||||
template void RegPrinter::print_vmm_prc<int, 64>(const char *name, int *ptr);
|
template void RegPrinter::print_vmm_prc<int, 64>(const char *name, const char *ori_name, int *ptr);
|
||||||
|
|
||||||
template <typename Vmm>
|
template <typename Vmm>
|
||||||
struct vmm_traits{};
|
struct vmm_traits{};
|
||||||
@ -124,42 +126,32 @@ void RegPrinter::postamble(jit_generator &h) {
|
|||||||
restore_reg(h);
|
restore_reg(h);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename REG_T>
|
// ABI requires 16-byte stack alignment before a call
|
||||||
const char * RegPrinter::get_name(REG_T reg, const char *name) {
|
void RegPrinter::align_rsp(jit_generator &h) {
|
||||||
const char *reg_name = reg.toString();
|
constexpr int alignment = 16;
|
||||||
|
h.mov(h.r15, h.rsp);
|
||||||
if (name == nullptr) {
|
h.and_(h.rsp, ~(alignment - 1));
|
||||||
return reg_name;
|
|
||||||
} else {
|
|
||||||
constexpr size_t len = 64;
|
|
||||||
constexpr size_t aux_len = 3;
|
|
||||||
static char full_name[len];
|
|
||||||
|
|
||||||
size_t total_len = std::strlen(name) + std::strlen(reg_name) + aux_len + 1;
|
|
||||||
if (total_len > len) {
|
|
||||||
return reg_name;
|
|
||||||
} else {
|
|
||||||
snprintf(full_name, len, "%s | %s", name, reg_name);
|
|
||||||
return full_name;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
void RegPrinter::restore_rsp(jit_generator &h) {
|
||||||
|
h.mov(h.rsp, h.r15);
|
||||||
}
|
}
|
||||||
|
|
||||||
template <typename PRC_T, typename REG_T>
|
template <typename PRC_T, typename REG_T>
|
||||||
void RegPrinter::print_vmm(jit_generator &h, REG_T vmm, const char *name) {
|
void RegPrinter::print_vmm(jit_generator &h, REG_T vmm, const char *name) {
|
||||||
preamble(h);
|
preamble(h);
|
||||||
|
|
||||||
name = get_name(vmm, name);
|
|
||||||
|
|
||||||
h.push(h.rax);
|
h.push(h.rax);
|
||||||
h.push(abi_param1);
|
h.push(abi_param1);
|
||||||
h.push(abi_param2);
|
h.push(abi_param2);
|
||||||
|
h.push(abi_param3);
|
||||||
{
|
{
|
||||||
const int vlen = vmm.isZMM() ? 64 : (vmm.isYMM() ? 32 : 16);
|
const int vlen = vmm.isZMM() ? 64 : (vmm.isYMM() ? 32 : 16);
|
||||||
h.sub(h.rsp, vlen);
|
h.sub(h.rsp, vlen);
|
||||||
h.uni_vmovups(h.ptr[h.rsp], vmm);
|
h.uni_vmovups(h.ptr[h.rsp], vmm);
|
||||||
|
|
||||||
h.mov(abi_param2, h.rsp);
|
h.mov(abi_param3, h.rsp);
|
||||||
|
h.mov(abi_param2, reinterpret_cast<size_t>(vmm.toString()));
|
||||||
h.mov(abi_param1, reinterpret_cast<size_t>(name));
|
h.mov(abi_param1, reinterpret_cast<size_t>(name));
|
||||||
if (vmm.isZMM()) {
|
if (vmm.isZMM()) {
|
||||||
h.mov(h.rax, reinterpret_cast<size_t>(&print_vmm_prc<PRC_T, 64>));
|
h.mov(h.rax, reinterpret_cast<size_t>(&print_vmm_prc<PRC_T, 64>));
|
||||||
@ -168,11 +160,14 @@ void RegPrinter::print_vmm(jit_generator &h, REG_T vmm, const char *name) {
|
|||||||
} else {
|
} else {
|
||||||
h.mov(h.rax, reinterpret_cast<size_t>(&print_vmm_prc<PRC_T, 16>));
|
h.mov(h.rax, reinterpret_cast<size_t>(&print_vmm_prc<PRC_T, 16>));
|
||||||
}
|
}
|
||||||
|
align_rsp(h);
|
||||||
h.call(h.rax);
|
h.call(h.rax);
|
||||||
|
restore_rsp(h);
|
||||||
|
|
||||||
h.add(h.rsp, vlen);
|
h.add(h.rsp, vlen);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
h.pop(abi_param3);
|
||||||
h.pop(abi_param2);
|
h.pop(abi_param2);
|
||||||
h.pop(abi_param1);
|
h.pop(abi_param1);
|
||||||
h.pop(h.rax);
|
h.pop(h.rax);
|
||||||
@ -184,21 +179,27 @@ template <typename PRC_T, typename REG_T>
|
|||||||
void RegPrinter::print_reg(jit_generator &h, REG_T reg, const char *name) {
|
void RegPrinter::print_reg(jit_generator &h, REG_T reg, const char *name) {
|
||||||
preamble(h);
|
preamble(h);
|
||||||
|
|
||||||
name = get_name(reg, name);
|
|
||||||
|
|
||||||
h.push(h.rax);
|
h.push(h.rax);
|
||||||
h.push(abi_param1);
|
h.push(abi_param1);
|
||||||
h.push(abi_param2);
|
h.push(abi_param2);
|
||||||
|
h.push(abi_param3);
|
||||||
{
|
{
|
||||||
h.mov(abi_param2, reg);
|
const int rlen = reg.getBit() / 8;
|
||||||
|
h.sub(h.rsp, rlen);
|
||||||
|
h.mov(h.ptr[h.rsp], reg);
|
||||||
|
|
||||||
|
h.mov(abi_param3, h.rsp);
|
||||||
|
h.mov(abi_param2, reinterpret_cast<size_t>(reg.toString()));
|
||||||
h.mov(abi_param1, reinterpret_cast<size_t>(name));
|
h.mov(abi_param1, reinterpret_cast<size_t>(name));
|
||||||
if (std::is_floating_point<PRC_T>::value)
|
h.mov(h.rax, reinterpret_cast<size_t>(&print_reg_prc<PRC_T>));
|
||||||
h.mov(h.rax, reinterpret_cast<size_t>(&print_reg_fp32));
|
align_rsp(h);
|
||||||
else
|
|
||||||
h.mov(h.rax, reinterpret_cast<size_t>(&print_reg_integer<PRC_T>));
|
|
||||||
h.call(h.rax);
|
h.call(h.rax);
|
||||||
|
restore_rsp(h);
|
||||||
|
|
||||||
|
h.add(h.rsp, rlen);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
h.pop(abi_param3);
|
||||||
h.pop(abi_param2);
|
h.pop(abi_param2);
|
||||||
h.pop(abi_param1);
|
h.pop(abi_param1);
|
||||||
h.pop(h.rax);
|
h.pop(h.rax);
|
||||||
|
@ -12,7 +12,10 @@ namespace intel_cpu {
|
|||||||
// Usage
|
// Usage
|
||||||
// 1. Include this header file where JIT kernels of the CPU plugin are implemented for register printing
|
// 1. Include this header file where JIT kernels of the CPU plugin are implemented for register printing
|
||||||
// 2. Invoke RegPrinter::print method. Here are some examples. Note that user friendly register name
|
// 2. Invoke RegPrinter::print method. Here are some examples. Note that user friendly register name
|
||||||
// will be printed, if it has been set. The original Xbyak register name will always be printed.
|
// will be printed, if it has been set. The current implementation does not copy the name, so if you
|
||||||
|
// choose to set a name for the register, do not pass it through a local variable; just pass a
|
||||||
|
// string literal directly to the interface, as in the examples. The original Xbyak register name will always be
|
||||||
|
// printed.
|
||||||
// Example 1:
|
// Example 1:
|
||||||
// Invocation: RegPrinter::print<float>(*this, vmm_val, "vmm_val");
|
// Invocation: RegPrinter::print<float>(*this, vmm_val, "vmm_val");
|
||||||
// Console: vmm_val | ymm0: {30, 20, 25, 29, 24, 31, 27, 23}
|
// Console: vmm_val | ymm0: {30, 20, 25, 29, 24, 31, 27, 23}
|
||||||
@ -69,10 +72,9 @@ private:
|
|||||||
template <typename PRC_T, typename REG_T>
|
template <typename PRC_T, typename REG_T>
|
||||||
static void print_reg(jit_generator &h, REG_T reg, const char *name);
|
static void print_reg(jit_generator &h, REG_T reg, const char *name);
|
||||||
template <typename PRC_T, size_t vlen>
|
template <typename PRC_T, size_t vlen>
|
||||||
static void print_vmm_prc(const char *name, PRC_T *ptr);
|
static void print_vmm_prc(const char *name, const char *ori_name, PRC_T *ptr);
|
||||||
template <typename T>
|
template <typename T>
|
||||||
static void print_reg_integer(const char *name, T val);
|
static void print_reg_prc(const char *name, const char *ori_name, T *val);
|
||||||
static void print_reg_fp32(const char *name, int val);
|
|
||||||
static void preamble(jit_generator &h);
|
static void preamble(jit_generator &h);
|
||||||
static void postamble(jit_generator &h);
|
static void postamble(jit_generator &h);
|
||||||
template <typename T>
|
template <typename T>
|
||||||
@ -81,8 +83,8 @@ private:
|
|||||||
static void restore_vmm(jit_generator &h);
|
static void restore_vmm(jit_generator &h);
|
||||||
static void save_reg(jit_generator &h);
|
static void save_reg(jit_generator &h);
|
||||||
static void restore_reg(jit_generator &h);
|
static void restore_reg(jit_generator &h);
|
||||||
template <typename REG_T>
|
static void align_rsp(jit_generator &h);
|
||||||
static const char * get_name(REG_T reg, const char *name);
|
static void restore_rsp(jit_generator &h);
|
||||||
static constexpr size_t reg_len = 8;
|
static constexpr size_t reg_len = 8;
|
||||||
static constexpr size_t reg_cnt = 16;
|
static constexpr size_t reg_cnt = 16;
|
||||||
};
|
};
|
||||||
|
Loading…
Reference in New Issue
Block a user