[Snippets] FP32 MHA postcommit fixes (#15180)

This commit is contained in:
Ivan Novoselov 2023-01-19 10:06:39 +00:00 committed by GitHub
parent d5f3bfa43e
commit ffcb83deba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 0 additions and 9 deletions

View File

@ -171,7 +171,6 @@ void LoopEnd::validate_and_infer_types() {
if (finalization_offsets.empty()) if (finalization_offsets.empty())
finalization_offsets.resize(loop_io_size, 0); finalization_offsets.resize(loop_io_size, 0);
set_output_size(num_inputs - 1); set_output_size(num_inputs - 1);
const auto& ins = inputs();
// All outputs are by-passed from inputs, except for the last one - it connects LoopBegin and LoopEnd // All outputs are by-passed from inputs, except for the last one - it connects LoopBegin and LoopEnd
for (int i = 0; i < num_inputs - 1; i++) for (int i = 0; i < num_inputs - 1; i++)
get_output_descriptor(i).set_tensor_ptr(get_input_descriptor(i).get_output().get_tensor_ptr()); get_output_descriptor(i).set_tensor_ptr(get_input_descriptor(i).get_output().get_tensor_ptr());

View File

@ -108,12 +108,6 @@ KernelEmitter::KernelEmitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl:
jcp = *reinterpret_cast<const jit_snippets_compile_args*>(kernel->compile_params); jcp = *reinterpret_cast<const jit_snippets_compile_args*>(kernel->compile_params);
// calc data access pattern. we'll need it for offsets calculation // calc data access pattern. we'll need it for offsets calculation
const auto& model = kernel->model; const auto& model = kernel->model;
const auto get_static_shape = [](const std::shared_ptr<ov::Node>& node) {
const auto& pshape = node->get_output_partial_shape(0);
if (pshape.is_dynamic())
IE_THROW() << "KernelEmitter can't calc offsets for dynamic shapes";
return pshape.get_shape();
};
const auto get_data_layout = [](const Output<ov::Node>& out, std::vector<size_t>& shape) { const auto get_data_layout = [](const Output<ov::Node>& out, std::vector<size_t>& shape) {
const auto& layout = ngraph::snippets::utils::get_node_output_layout(out.get_node_shared_ptr()); const auto& layout = ngraph::snippets::utils::get_node_output_layout(out.get_node_shared_ptr());
// default access pattern // default access pattern
@ -1061,7 +1055,6 @@ void HorizonMaxEmitter::emit_isa(const std::vector<size_t> &in, const std::vecto
Xmm aux_xmm = Xmm(aux_vec_idxs[0]); Xmm aux_xmm = Xmm(aux_vec_idxs[0]);
Reg64 aux_reg = Reg64(aux_gpr_idxs[0]); Reg64 aux_reg = Reg64(aux_gpr_idxs[0]);
Reg32 aux_reg_32 = Reg32(aux_reg.getIdx());
const size_t vlen = dnnl::impl::cpu::x64::cpu_isa_traits<isa>::vlen; const size_t vlen = dnnl::impl::cpu::x64::cpu_isa_traits<isa>::vlen;
const size_t vec_size = vlen / sizeof(float); const size_t vec_size = vlen / sizeof(float);
@ -1107,7 +1100,6 @@ void HorizonSumEmitter::emit_isa(const std::vector<size_t> &in, const std::vecto
Xmm aux_xmm = Xmm(aux_vec_idxs[0]); Xmm aux_xmm = Xmm(aux_vec_idxs[0]);
Reg64 aux_reg = Reg64(aux_gpr_idxs[0]); Reg64 aux_reg = Reg64(aux_gpr_idxs[0]);
Reg32 aux_reg_32 = Reg32(aux_reg.getIdx());
const size_t vlen = dnnl::impl::cpu::x64::cpu_isa_traits<isa>::vlen; const size_t vlen = dnnl::impl::cpu::x64::cpu_isa_traits<isa>::vlen;
const size_t vec_size = vlen / sizeof(float); const size_t vec_size = vlen / sizeof(float);