[Snippets] FP32 MHA postcommit fixes (#15180)
This commit is contained in:
parent
d5f3bfa43e
commit
ffcb83deba
@ -171,7 +171,6 @@ void LoopEnd::validate_and_infer_types() {
|
|||||||
if (finalization_offsets.empty())
|
if (finalization_offsets.empty())
|
||||||
finalization_offsets.resize(loop_io_size, 0);
|
finalization_offsets.resize(loop_io_size, 0);
|
||||||
set_output_size(num_inputs - 1);
|
set_output_size(num_inputs - 1);
|
||||||
const auto& ins = inputs();
|
|
||||||
// All outputs are by-passed from inputs, except for the last one - it connects LoopBegin and LoopEnd
|
// All outputs are by-passed from inputs, except for the last one - it connects LoopBegin and LoopEnd
|
||||||
for (int i = 0; i < num_inputs - 1; i++)
|
for (int i = 0; i < num_inputs - 1; i++)
|
||||||
get_output_descriptor(i).set_tensor_ptr(get_input_descriptor(i).get_output().get_tensor_ptr());
|
get_output_descriptor(i).set_tensor_ptr(get_input_descriptor(i).get_output().get_tensor_ptr());
|
||||||
|
@ -108,12 +108,6 @@ KernelEmitter::KernelEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl:
|
|||||||
jcp = *reinterpret_cast<const jit_snippets_compile_args*>(kernel->compile_params);
|
jcp = *reinterpret_cast<const jit_snippets_compile_args*>(kernel->compile_params);
|
||||||
// calc data access pattern. we'll need it for offsets calculation
|
// calc data access pattern. we'll need it for offsets calculation
|
||||||
const auto& model = kernel->model;
|
const auto& model = kernel->model;
|
||||||
const auto get_static_shape = [](const std::shared_ptr<ov::Node>& node) {
|
|
||||||
const auto& pshape = node->get_output_partial_shape(0);
|
|
||||||
if (pshape.is_dynamic())
|
|
||||||
IE_THROW() << "KernelEmitter can't calc offsets for dynamic shapes";
|
|
||||||
return pshape.get_shape();
|
|
||||||
};
|
|
||||||
const auto get_data_layout = [](const Output<ov::Node>& out, std::vector<size_t>& shape) {
|
const auto get_data_layout = [](const Output<ov::Node>& out, std::vector<size_t>& shape) {
|
||||||
const auto& layout = ngraph::snippets::utils::get_node_output_layout(out.get_node_shared_ptr());
|
const auto& layout = ngraph::snippets::utils::get_node_output_layout(out.get_node_shared_ptr());
|
||||||
// default access pattern
|
// default access pattern
|
||||||
@ -1061,7 +1055,6 @@ void HorizonMaxEmitter::emit_isa(const std::vector<size_t> &in, const std::vecto
|
|||||||
Xmm aux_xmm = Xmm(aux_vec_idxs[0]);
|
Xmm aux_xmm = Xmm(aux_vec_idxs[0]);
|
||||||
|
|
||||||
Reg64 aux_reg = Reg64(aux_gpr_idxs[0]);
|
Reg64 aux_reg = Reg64(aux_gpr_idxs[0]);
|
||||||
Reg32 aux_reg_32 = Reg32(aux_reg.getIdx());
|
|
||||||
|
|
||||||
const size_t vlen = dnnl::impl::cpu::x64::cpu_isa_traits<isa>::vlen;
|
const size_t vlen = dnnl::impl::cpu::x64::cpu_isa_traits<isa>::vlen;
|
||||||
const size_t vec_size = vlen / sizeof(float);
|
const size_t vec_size = vlen / sizeof(float);
|
||||||
@ -1107,7 +1100,6 @@ void HorizonSumEmitter::emit_isa(const std::vector<size_t> &in, const std::vecto
|
|||||||
Xmm aux_xmm = Xmm(aux_vec_idxs[0]);
|
Xmm aux_xmm = Xmm(aux_vec_idxs[0]);
|
||||||
|
|
||||||
Reg64 aux_reg = Reg64(aux_gpr_idxs[0]);
|
Reg64 aux_reg = Reg64(aux_gpr_idxs[0]);
|
||||||
Reg32 aux_reg_32 = Reg32(aux_reg.getIdx());
|
|
||||||
|
|
||||||
const size_t vlen = dnnl::impl::cpu::x64::cpu_isa_traits<isa>::vlen;
|
const size_t vlen = dnnl::impl::cpu::x64::cpu_isa_traits<isa>::vlen;
|
||||||
const size_t vec_size = vlen / sizeof(float);
|
const size_t vec_size = vlen / sizeof(float);
|
||||||
|
Loading…
Reference in New Issue
Block a user