[CPU] Support MHA optimization (#12643)

* [CPU] Support MHA optimization

* [CPU] Extend pattern supported by MHA node

* [CPU] MHA: fixed int8 perf issue

Co-authored-by: Gu, Jianan <jianan.gu@intel.com>
This commit is contained in:
Gorokhov Dmitriy
2022-08-23 12:50:02 +04:00
committed by GitHub
parent b4d18bb406
commit a6bfc0cf0e
17 changed files with 2993 additions and 8 deletions

View File

@@ -197,6 +197,7 @@ const InferenceEngine::details::caseless_unordered_map<std::string, Type> type_t
{ "Subgraph", Type::Subgraph},
{ "PriorBox", Type::PriorBox},
{ "PriorBoxClustered", Type::PriorBoxClustered},
{ "MHA", Type::MHA},
};
Type TypeFromName(const std::string& type) {
@@ -388,6 +389,8 @@ std::string NameFromType(const Type type) {
return "Reference";
case Type::Subgraph:
return "Subgraph";
case Type::MHA:
return "MHA";
default:
return "Unknown";
}

View File

@@ -108,6 +108,7 @@ enum class Type {
Subgraph,
PriorBox,
PriorBoxClustered,
MHA
};
enum class Algorithm {

View File

@@ -144,9 +144,6 @@ void jit_emitter::emitter_preamble(const std::vector<size_t> &in_idxs, const std
for (size_t i = 0; i < preserved_vec_idxs.size(); ++i) {
push_vec(h->ptr[h->rsp + i * get_vec_length()], preserved_vec_idxs[i]);
}
if (!entry_map_.empty())
load_table_addr();
}
@@ -204,6 +201,9 @@ void jit_emitter::emit_code(const std::vector<size_t> &in_idxs, const std::vecto
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) const {
emitter_preamble(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs);
if (!entry_map_.empty())
load_table_addr();
emit_impl(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs, nullptr);
emitter_postamble();
@@ -214,6 +214,9 @@ void jit_emitter::emit_code(const std::vector<size_t> &in_idxs, const std::vecto
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) {
emitter_preamble(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs);
if (!entry_map_.empty())
load_table_addr();
emit_impl(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs, emit_context.get());
emitter_postamble();

View File

@@ -516,6 +516,34 @@ void jit_load_emitter::register_table_entries() {
push_arg_entry_of("float_max", 0x7f7fffff, true);
}
// Context-free emit_code overload: wraps emit_impl() with the standard
// preamble/postamble and passes a null emitter context. The constant table
// base address is never loaded on this path.
void jit_load_emitter::emit_code(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) const {
emitter_preamble(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs);
emit_impl(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs, nullptr);
emitter_postamble();
}
// emit_code overload that requires a load_emitter_context: the context carries
// the load parameters (precision, count, fill behavior). Throws if the caller
// passed a context of the wrong dynamic type. The constant table address is
// loaded only when table entries exist AND the context requests filling
// (is_fill_), since the table is only consumed by the fill path.
void jit_load_emitter::emit_code(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
                                 const std::shared_ptr<const emitter_context> &emit_context,
                                 const std::vector<size_t> &pool_vec_idxs, const std::vector<size_t> &pool_gpr_idxs) {
    emitter_preamble(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs);

    const auto* load_emitter_context = dynamic_cast<const ov::intel_cpu::load_emitter_context*>(emit_context.get());
    if (load_emitter_context == nullptr) {
        // Fixed typo in the diagnostic ("emmiter" -> "emitter").
        IE_THROW() << "Load emitter in " << name << " does not get load emitter context.";
    }

    if (!entry_map_.empty() && load_emitter_context->is_fill_)
        load_table_addr();

    emit_impl(in_idxs, out_idxs, pool_vec_idxs, pool_gpr_idxs, emit_context.get());
    emitter_postamble();
}
/// STORE ///
jit_store_emitter::jit_store_emitter(jit_generator *host, cpu_isa_t host_isa,
Precision exec_prc, emitter_in_out_map in_out_type)

View File

@@ -47,6 +47,14 @@ class jit_load_emitter : public jit_emitter {
public:
jit_load_emitter(dnnl::impl::cpu::x64::jit_generator *host, dnnl::impl::cpu::x64::cpu_isa_t host_isa,
InferenceEngine::Precision exec_prc = InferenceEngine::Precision::FP32, emitter_in_out_map in_out_type = emitter_in_out_map::gpr_to_vec);
void emit_code(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::vector<size_t> &pool_vec_idxs = {}, const std::vector<size_t> &pool_gpr_idxs = {}) const override;
void emit_code(const std::vector<size_t> &in_idxs, const std::vector<size_t> &out_idxs,
const std::shared_ptr<const emitter_context> &emit_context,
const std::vector<size_t> &pool_vec_idxs = {}, const std::vector<size_t> &pool_gpr_idxs = {}) override;
/**
* load_num values with src_prc precision are loaded from ptr[Reg64(in_idxs[0]) + offset_byte] address to Vmm[out_idxs[0]] as dst_prc.
* is_fill: when load_num can not fully fit in vector register, whether fill_value should be filled as default values.

View File

@@ -7,6 +7,7 @@
#include "ngraph_transformations/op/leaky_relu.hpp"
#include "ngraph_transformations/op/power_static.hpp"
#include "ngraph_transformations/op/swish_cpu.hpp"
#include "ngraph_transformations/op/mha.hpp"
#include <ngraph/ngraph.hpp>
#include <ngraph_ops/type_relaxed.hpp>
@@ -40,6 +41,7 @@ std::map<std::string, ngraph::OpSet> Extension::getOpSets() {
NGRAPH_OP(LeakyReluNode, ov::intel_cpu)
NGRAPH_OP(PowerStaticNode, ov::intel_cpu)
NGRAPH_OP(SwishNode, ov::intel_cpu)
NGRAPH_OP(MHANode, ov::intel_cpu)
#undef NGRAPH_OP
return opset;

View File

@@ -0,0 +1,559 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "mha_fusion.hpp"
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include "op/mha.hpp"
#include "itt.hpp"
// TODO: draw pattern
// Matcher pass: collapses the FP32 MHA subgraph
//   MatMul0(Transpose0(Q), Multiply(Transpose1(K), scale)) -> Add(mask)
//   -> Reshape -> Softmax(axis=1) -> Reshape -> MatMul1(., Transpose2(V)) -> Transpose3
// into a single ov::intel_cpu::MHANode.
ov::intel_cpu::MHAFusion::MHAFusion() {
MATCHER_SCOPE(MHAFusion);
// Activation inputs (Q, K, mask, V) must have static shapes; the remaining
// pattern inputs are constants (multiply scale, transpose orders, reshape patterns).
auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in2 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in4 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in5 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in6 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in7 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in9 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in10 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
// Pattern graph matching the MHA computation described above.
auto transpose0 = std::make_shared<ngraph::opset3::Transpose>(in0, in4);
auto transpose1 = std::make_shared<ngraph::opset3::Transpose>(in1, in5);
auto mul = std::make_shared<ngraph::opset3::Multiply>(transpose1, in2);
auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, mul);
auto add = std::make_shared<ngraph::opset4::Add>(matmul0, in3);
auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(add, in6, true);
auto softmax = std::make_shared<ngraph::opset1::Softmax>(reshape0);
auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softmax, in7, true);
auto transpose2 = std::make_shared<ngraph::opset3::Transpose>(in8, in9);
auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(reshape1, transpose2);
auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(matmul1, in10);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto transpose0_in = pattern_to_output.at(in0);
auto transpose1_in = pattern_to_output.at(in1);
auto mul_in1 = pattern_to_output.at(in2);
auto add_in1 = pattern_to_output.at(in3);
auto transpose2_in = pattern_to_output.at(in8);
// TODO: check transpose order
// The multiply factor must be constant so it can be folded into the MHANode scales.
auto const_node = std::dynamic_pointer_cast<ngraph::opset3::Constant>(mul_in1.get_node_shared_ptr());
if (!const_node)
return false;
std::vector<float> mul_scales;
if (auto mul_node = std::dynamic_pointer_cast<ngraph::opset3::Multiply>(pattern_to_output.at(mul).get_node_shared_ptr())) {
mul_scales = std::dynamic_pointer_cast<ngraph::opset4::Constant>(mul_node->get_input_node_shared_ptr(1))->cast_vector<float>();
} else {
return false;
}
// MHANode expects plain (non-transposing) MatMuls: any transposition is
// represented by the explicit Transpose nodes that are part of the pattern.
auto matmul0_node = std::dynamic_pointer_cast<ngraph::opset3::MatMul>(pattern_to_output.at(matmul0).get_node_shared_ptr());
if (!matmul0_node)
return false;
if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b())
return false;
auto reshape0_node = std::dynamic_pointer_cast<ngraph::opset1::Reshape>(pattern_to_output.at(reshape0).get_node_shared_ptr());
if (!reshape0_node)
return false;
if (auto reshape0_pattern = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(in6).get_node_shared_ptr())) {
// TODO: add valid condition based on reshape pattern
auto reshape0_pattern_values = reshape0_pattern->cast_vector<int64_t>();
} else {
return false;
}
auto reshape1_node = std::dynamic_pointer_cast<ngraph::opset1::Reshape>(pattern_to_output.at(reshape1).get_node_shared_ptr());
if (!reshape1_node)
return false;
if (auto reshape1_pattern = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(in7).get_node_shared_ptr())) {
// TODO: add valid condition based on reshape pattern
auto reshape1_pattern_values = reshape1_pattern->cast_vector<int64_t>();
} else {
return false;
}
// The two reshapes must form a flatten-to-2D + restore pair around the softmax.
if (reshape0_node->get_output_partial_shape(0).rank() != 2 || reshape0_node->get_input_partial_shape(0) != reshape1_node->get_output_partial_shape(0))
return false;
// Softmax must run over axis 1 of the flattened 2D tensor.
auto softmax_node = std::dynamic_pointer_cast<ngraph::opset1::Softmax>(pattern_to_output.at(softmax).get_node_shared_ptr());
if (!softmax_node)
return false;
if (softmax_node->get_axis() != 1)
return false;
auto matmul1_node = std::dynamic_pointer_cast<ngraph::opset3::MatMul>(pattern_to_output.at(matmul1).get_node_shared_ptr());
if (!matmul1_node)
return false;
if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b())
return false;
// In this pattern the scale multiply is applied before MatMul0 (on the K branch).
bool is_mul_first = true;
auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr();
auto mha = std::make_shared<ov::intel_cpu::MHANode>(transpose0_in, transpose1_in, add_in1, transpose2_in, mul_scales, is_mul_first,
transpose3_node->output(0).get_element_type());
mha->set_friendly_name(m.get_match_root()->get_friendly_name());
// Propagate runtime info from every replaced node onto the fused MHANode.
ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(),
pattern_to_output.at(transpose1).get_node_shared_ptr(),
pattern_to_output.at(mul).get_node_shared_ptr(),
pattern_to_output.at(matmul0).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(reshape0).get_node_shared_ptr(),
pattern_to_output.at(softmax).get_node_shared_ptr(),
pattern_to_output.at(reshape1).get_node_shared_ptr(),
pattern_to_output.at(transpose2).get_node_shared_ptr(),
pattern_to_output.at(matmul1).get_node_shared_ptr(),
pattern_to_output.at(transpose3).get_node_shared_ptr(),
},
mha);
// Give the plugin's transformation callback a chance to reject the fusion.
if (transformation_callback(mha)) {
return false;
}
ngraph::replace_node(m.get_match_root(), mha);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose3, matcher_name);
this->register_matcher(m, callback);
}
// Matcher pass: same MHA fusion as MHAFusion but without the scale Multiply
// and without the Reshape pair around the softmax:
//   MatMul0(Transpose0(Q), Transpose1(K)) -> Add(mask) -> Softmax(axis=3)
//   -> MatMul1(., Transpose2(V)) -> Transpose3
ov::intel_cpu::MHAFusion2::MHAFusion2() {
MATCHER_SCOPE(MHAFusion2);
// Activation inputs (Q, K, mask, V) must have static shapes; the remaining
// pattern inputs are transpose-order constants.
auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in4 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in5 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in6 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in7 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in9 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in10 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
// Pattern graph matching the MHA computation described above.
auto transpose0 = std::make_shared<ngraph::opset3::Transpose>(in0, in4);
auto transpose1 = std::make_shared<ngraph::opset3::Transpose>(in1, in5);
auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1);
auto add = std::make_shared<ngraph::opset4::Add>(matmul0, in3);
auto softmax = std::make_shared<ngraph::opset1::Softmax>(add);
auto transpose2 = std::make_shared<ngraph::opset3::Transpose>(in8, in9);
auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, transpose2);
auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(matmul1, in10);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto transpose0_in = pattern_to_output.at(in0);
auto transpose1_in = pattern_to_output.at(in1);
auto add_in1 = pattern_to_output.at(in3);
auto transpose2_in = pattern_to_output.at(in8);
// TODO: check transpose order
// MHANode expects plain (non-transposing) MatMuls.
auto matmul0_node = std::dynamic_pointer_cast<ngraph::opset3::MatMul>(pattern_to_output.at(matmul0).get_node_shared_ptr());
if (!matmul0_node)
return false;
if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b())
return false;
// No flattening reshape in this pattern: softmax runs over the last (4th) axis.
auto softmax_node = std::dynamic_pointer_cast<ngraph::opset1::Softmax>(pattern_to_output.at(softmax).get_node_shared_ptr());
if (!softmax_node)
return false;
if (softmax_node->get_axis() != 3)
return false;
auto matmul1_node = std::dynamic_pointer_cast<ngraph::opset3::MatMul>(pattern_to_output.at(matmul1).get_node_shared_ptr());
if (!matmul1_node)
return false;
if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b())
return false;
// No scale multiply in this pattern: empty mul_scales, is_mul_first = false.
auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr();
auto mha = std::make_shared<ov::intel_cpu::MHANode>(transpose0_in, transpose1_in, add_in1, transpose2_in, std::vector<float>(), false,
transpose3_node->output(0).get_element_type());
mha->set_friendly_name(m.get_match_root()->get_friendly_name());
// Propagate runtime info from every replaced node onto the fused MHANode.
ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(),
pattern_to_output.at(transpose1).get_node_shared_ptr(),
pattern_to_output.at(matmul0).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(softmax).get_node_shared_ptr(),
pattern_to_output.at(transpose2).get_node_shared_ptr(),
pattern_to_output.at(matmul1).get_node_shared_ptr(),
pattern_to_output.at(transpose3).get_node_shared_ptr(),
},
mha);
// Give the plugin's transformation callback a chance to reject the fusion.
if (transformation_callback(mha)) {
return false;
}
ngraph::replace_node(m.get_match_root(), mha);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose3, matcher_name);
this->register_matcher(m, callback);
}
// Tries to collapse a FakeQuantize node into a plain multiplicative (per-channel) scale.
// Returns the input-scale vector (isc) when the FQ is equivalent to "crop + multiply":
//   * u8 output: zero crop-low bounds, zero input shift, identity output scale/shift;
//   * i8 output: symmetric +/-128 input shift, identity output scale, -128 output shift,
//     with crop bounds that map exactly onto [-128, 127].
// Returns an empty vector when the FQ cannot be represented this way.
static std::vector<float> simplifyToScale(const std::shared_ptr<ngraph::opset1::FakeQuantize>& fq_node) {
    auto levels = fq_node->get_levels();
    auto input_low = std::dynamic_pointer_cast<ngraph::opset4::Constant>(fq_node->get_input_node_shared_ptr(1))->cast_vector<float>();
    auto input_high = std::dynamic_pointer_cast<ngraph::opset4::Constant>(fq_node->get_input_node_shared_ptr(2))->cast_vector<float>();
    auto output_low = std::dynamic_pointer_cast<ngraph::opset4::Constant>(fq_node->get_input_node_shared_ptr(3))->cast_vector<float>();
    auto output_high = std::dynamic_pointer_cast<ngraph::opset4::Constant>(fq_node->get_input_node_shared_ptr(4))->cast_vector<float>();

    // Crop bounds are exactly the FQ input range; the original element-wise copy
    // loops were redundant, so plain vector copies are used instead.
    std::vector<float> cl(input_low);
    std::vector<float> ch(input_high);

    // Input scale/shift of the quantization transform: q = x * isc + ish.
    std::vector<float> isc, ish, osc, osh;
    for (size_t i = 0; i < std::max(input_low.size(), input_high.size()); i++) {
        float il = input_low[input_low.size() == 1 ? 0 : i];
        float ih = input_high[input_high.size() == 1 ? 0 : i];
        // NOTE(review): ih == il would produce inf/nan here — assumed not to occur
        // for a well-formed FakeQuantize; confirm upstream validation guarantees it.
        isc.push_back((levels - 1) / (ih - il));
        ish.push_back(-il * (levels - 1) / (ih - il));
    }
    // Output scale/shift of the dequantization transform: y = q * osc + osh.
    for (size_t i = 0; i < std::max(output_low.size(), output_high.size()); i++) {
        float ol = output_low[output_low.size() == 1 ? 0 : i];
        float oh = output_high[output_high.size() == 1 ? 0 : i];
        osc.push_back((oh - ol) / (levels - 1));
        osh.push_back(ol);
    }

    std::vector<float> outScale;

    // u8 case: FQ degenerates to "multiply by isc" when all shifts are zero,
    // the crop-low bounds are zero and the dequantization is an identity.
    if (fq_node->get_output_element_type(0) == ngraph::element::u8 &&
            std::all_of(cl.cbegin(), cl.cend(), [](float val) { return val == 0.0f; }) &&
            std::all_of(ish.cbegin(), ish.cend(), [](float val) { return val == 0.0f; }) &&
            std::all_of(osc.cbegin(), osc.cend(), [](float val) { return val == 1.0f; }) &&
            std::all_of(osh.cbegin(), osh.cend(), [](float val) { return val == 0.0f; })) {
        outScale = isc;
    }

    // i8 case: symmetric quantization with +/-128 shifts; additionally the crop
    // bounds scaled by isc must align with the full signed int8 range [-128, 127].
    if (fq_node->get_output_element_type(0) == ngraph::element::i8 &&
            std::all_of(ish.cbegin(), ish.cend(), [](float val) { return std::abs(val - 128.f) < 0.0001f; }) &&
            std::all_of(osc.cbegin(), osc.cend(), [](float val) { return val == 1.f; }) &&
            std::all_of(osh.cbegin(), osh.cend(), [](float val) { return std::abs(val + 128.f) < 0.0001f; })) {
        bool isCropAligned = true;
        for (size_t i = 0; i < std::max(cl.size(), isc.size()); i++) {
            if (std::abs(cl[cl.size() == 1 ? 0 : i] * isc[isc.size() == 1 ? 0 : i] + 128.f) > 0.0001f) {
                isCropAligned = false;
            }
        }
        for (size_t i = 0; i < std::max(ch.size(), isc.size()); i++) {
            if (std::abs(ch[ch.size() == 1 ? 0 : i] * isc[isc.size() == 1 ? 0 : i] - 127.f) > 0.0001f) {
                isCropAligned = false;
            }
        }
        if (isCropAligned) {
            outScale = isc;
        }
    }

    return outScale;
}
// TODO: draw pattern
// Matcher pass: quantized MHA variant with FakeQuantize after MatMul0, after
// the softmax Reshape pair, and after MatMul1:
//   MatMul0(Transpose0(Q), Transpose1(K)) -> FQ0 -> Add(mask) -> Multiply(scale)
//   -> Reshape -> Softmax(axis=1) -> Reshape -> FQ1 -> MatMul1(., Transpose2(V))
//   -> FQ2 -> Transpose3
// Each FQ must be collapsible to a plain scale (see simplifyToScale); the scales
// are carried by the fused MHANode.
ov::intel_cpu::MHAQuantFusion::MHAQuantFusion() {
MATCHER_SCOPE(MHAQuantFusion);
// Activation inputs (Q, K, mask, V) must have static shapes; the remaining
// pattern inputs are constants (scale, transpose orders, reshape patterns).
auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in2 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in4 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in5 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in6 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in7 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in9 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in10 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
// Pattern graph matching the quantized MHA computation described above.
auto transpose0 = std::make_shared<ngraph::opset3::Transpose>(in0, in4);
auto transpose1 = std::make_shared<ngraph::opset3::Transpose>(in1, in5);
auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1);
auto fakeQuantize0 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({matmul0,
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
auto add = std::make_shared<ngraph::opset4::Add>(fakeQuantize0, in3);
auto mul = std::make_shared<ngraph::opset3::Multiply>(add, in2);
auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(mul, in6, true);
auto softmax = std::make_shared<ngraph::opset1::Softmax>(reshape0);
auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softmax, in7, true);
auto fakeQuantize1 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({reshape1,
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
auto transpose2 = std::make_shared<ngraph::opset3::Transpose>(in8, in9);
auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(fakeQuantize1, transpose2);
auto fakeQuantize2 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({matmul1,
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(fakeQuantize2, in10);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto transpose0_in = pattern_to_output.at(in0);
auto transpose1_in = pattern_to_output.at(in1);
auto add_in1 = pattern_to_output.at(in3);
auto transpose2_in = pattern_to_output.at(in8);
// todo: check mul shape
// The multiply factor must be constant so it can be folded into the MHANode scales.
std::vector<float> mul_scales;
if (auto mul_node = std::dynamic_pointer_cast<ngraph::opset3::Multiply>(pattern_to_output.at(mul).get_node_shared_ptr())) {
mul_scales = std::dynamic_pointer_cast<ngraph::opset4::Constant>(mul_node->get_input_node_shared_ptr(1))->cast_vector<float>();
} else {
return false;
}
// TODO: check transpose order
// MHANode expects plain (non-transposing) MatMuls.
auto matmul0_node = std::dynamic_pointer_cast<ngraph::opset3::MatMul>(pattern_to_output.at(matmul0).get_node_shared_ptr());
if (!matmul0_node)
return false;
if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b())
return false;
// FQ after MatMul0 must be representable as a pure scale.
std::vector<float> fq0_scale;
if (auto fq_node = std::dynamic_pointer_cast<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize0).get_node_shared_ptr())) {
fq0_scale = simplifyToScale(fq_node);
if (!fq0_scale.size())
return false;
}
auto reshape0_node = std::dynamic_pointer_cast<ngraph::opset1::Reshape>(pattern_to_output.at(reshape0).get_node_shared_ptr());
if (!reshape0_node)
return false;
if (auto reshape0_pattern = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(in6).get_node_shared_ptr())) {
// TODO: add valid condition based on reshape pattern
auto reshape0_pattern_values = reshape0_pattern->cast_vector<int64_t>();
} else {
return false;
}
auto reshape1_node = std::dynamic_pointer_cast<ngraph::opset1::Reshape>(pattern_to_output.at(reshape1).get_node_shared_ptr());
if (!reshape1_node)
return false;
if (auto reshape1_pattern = std::dynamic_pointer_cast<ngraph::opset4::Constant>(pattern_to_output.at(in7).get_node_shared_ptr())) {
// TODO: add valid condition based on reshape pattern
auto reshape1_pattern_values = reshape1_pattern->cast_vector<int64_t>();
} else {
return false;
}
// The two reshapes must form a flatten-to-2D + restore pair around the softmax.
if (reshape0_node->get_output_partial_shape(0).rank() != 2 || reshape0_node->get_input_partial_shape(0) != reshape1_node->get_output_partial_shape(0))
return false;
// Softmax must run over axis 1 of the flattened 2D tensor.
auto softmax_node = std::dynamic_pointer_cast<ngraph::opset1::Softmax>(pattern_to_output.at(softmax).get_node_shared_ptr());
if (!softmax_node)
return false;
if (softmax_node->get_axis() != 1)
return false;
// FQ after the softmax reshape must be representable as a pure scale.
std::vector<float> fq1_scale;
if (auto fq_node = std::dynamic_pointer_cast<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize1).get_node_shared_ptr())) {
fq1_scale = simplifyToScale(fq_node);
if (!fq1_scale.size())
return false;
}
auto matmul1_node = std::dynamic_pointer_cast<ngraph::opset3::MatMul>(pattern_to_output.at(matmul1).get_node_shared_ptr());
if (!matmul1_node)
return false;
if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b())
return false;
// FQ after MatMul1 must be representable as a pure scale.
std::vector<float> fq2_scale;
if (auto fq_node = std::dynamic_pointer_cast<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize2).get_node_shared_ptr())) {
fq2_scale = simplifyToScale(fq_node);
if (!fq2_scale.size())
return false;
}
// In this pattern the scale multiply is applied after MatMul0, not before.
bool is_mul_first = false;
auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr();
// NOTE(review): the first scale slot is left empty — presumably fq_scales0 is
// reserved for an FQ on the K input (as in MHAQuantFusion2); confirm against mha.hpp.
auto mha = std::make_shared<ov::intel_cpu::MHANode>(transpose0_in, transpose1_in, add_in1, transpose2_in, mul_scales, is_mul_first,
std::vector<float>(), fq0_scale, fq1_scale, fq2_scale,
transpose3_node->output(0).get_element_type());
mha->set_friendly_name(m.get_match_root()->get_friendly_name());
// Propagate runtime info from every replaced node onto the fused MHANode.
ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(),
pattern_to_output.at(transpose1).get_node_shared_ptr(),
pattern_to_output.at(matmul0).get_node_shared_ptr(),
pattern_to_output.at(fakeQuantize0).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(mul).get_node_shared_ptr(),
pattern_to_output.at(reshape0).get_node_shared_ptr(),
pattern_to_output.at(softmax).get_node_shared_ptr(),
pattern_to_output.at(reshape1).get_node_shared_ptr(),
pattern_to_output.at(fakeQuantize1).get_node_shared_ptr(),
pattern_to_output.at(transpose2).get_node_shared_ptr(),
pattern_to_output.at(matmul1).get_node_shared_ptr(),
pattern_to_output.at(fakeQuantize2).get_node_shared_ptr(),
pattern_to_output.at(transpose3).get_node_shared_ptr(),
},
mha);
// Give the plugin's transformation callback a chance to reject the fusion.
if (transformation_callback(mha)) {
return false;
}
ngraph::replace_node(m.get_match_root(), mha);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose3, matcher_name);
this->register_matcher(m, callback);
}
// TODO: draw pattern
// Matcher pass: quantized MHA variant with FakeQuantize on the transposed K
// input (before MatMul0) and after MatMul1:
//   MatMul0(Transpose0(Q), FQ0(Transpose1(K))) -> Multiply(scale) -> Add(mask)
//   -> Softmax(axis=3) -> MatMul1(., Transpose2(V)) -> FQ1 -> Transpose3
// Each FQ must be collapsible to a plain scale (see simplifyToScale); the scales
// are carried by the fused MHANode.
ov::intel_cpu::MHAQuantFusion2::MHAQuantFusion2() {
    MATCHER_SCOPE(MHAQuantFusion2);

    // Activation inputs (Q, K, mask, V) must have static shapes; the remaining
    // pattern inputs are constants (scale, transpose orders).
    auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
    auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
    auto in2 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
    auto in4 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in5 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
    auto in9 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in10 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();

    // Pattern graph matching the quantized MHA computation described above.
    auto transpose0 = std::make_shared<ngraph::opset3::Transpose>(in0, in4);
    auto transpose1 = std::make_shared<ngraph::opset3::Transpose>(in1, in5);
    auto fakeQuantize0 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({transpose1,
                                                                                  ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
                                                                                  ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
                                                                                  ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
                                                                                  ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
    auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, fakeQuantize0);
    auto mul = std::make_shared<ngraph::opset3::Multiply>(matmul0, in2);
    auto add = std::make_shared<ngraph::opset4::Add>(mul, in3);
    auto softmax = std::make_shared<ngraph::opset1::Softmax>(add);
    auto transpose2 = std::make_shared<ngraph::opset3::Transpose>(in8, in9);
    auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, transpose2);
    auto fakeQuantize1 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({matmul1,
                                                                                  ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
                                                                                  ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
                                                                                  ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
                                                                                  ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
    auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(fakeQuantize1, in10);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
        auto& pattern_to_output = m.get_pattern_value_map();
        auto transpose0_in = pattern_to_output.at(in0);
        auto transpose1_in = pattern_to_output.at(in1);
        auto add_in1 = pattern_to_output.at(in3);
        auto transpose2_in = pattern_to_output.at(in8);

        // todo: check mul shape
        // The multiply factor must be constant so it can be folded into the MHANode scales.
        std::vector<float> mul_scales;
        if (auto mul_node = std::dynamic_pointer_cast<ngraph::opset3::Multiply>(pattern_to_output.at(mul).get_node_shared_ptr())) {
            mul_scales = std::dynamic_pointer_cast<ngraph::opset4::Constant>(mul_node->get_input_node_shared_ptr(1))->cast_vector<float>();
        } else {
            return false;
        }

        // TODO: check transpose order
        // MHANode expects plain (non-transposing) MatMuls.
        auto matmul0_node = std::dynamic_pointer_cast<ngraph::opset3::MatMul>(pattern_to_output.at(matmul0).get_node_shared_ptr());
        if (!matmul0_node)
            return false;
        if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b())
            return false;

        // FQ on the transposed K input must be representable as a pure scale.
        std::vector<float> fq0_scale;
        if (auto fq_node = std::dynamic_pointer_cast<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize0).get_node_shared_ptr())) {
            fq0_scale = simplifyToScale(fq_node);
            if (!fq0_scale.size())
                return false;
        }

        // No flattening reshape in this pattern: softmax runs over the last (4th) axis.
        auto softmax_node = std::dynamic_pointer_cast<ngraph::opset1::Softmax>(pattern_to_output.at(softmax).get_node_shared_ptr());
        if (!softmax_node)
            return false;
        if (softmax_node->get_axis() != 3)
            return false;

        // FQ after MatMul1 must be representable as a pure scale.
        std::vector<float> fq1_scale;
        if (auto fq_node = std::dynamic_pointer_cast<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize1).get_node_shared_ptr())) {
            fq1_scale = simplifyToScale(fq_node);
            if (!fq1_scale.size())
                return false;
        }

        auto matmul1_node = std::dynamic_pointer_cast<ngraph::opset3::MatMul>(pattern_to_output.at(matmul1).get_node_shared_ptr());
        if (!matmul1_node)
            return false;
        if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b())
            return false;

        // NOTE(review): a second simplifyToScale() pass over the same fakeQuantize1
        // node used to sit here, filling an unused fq2_scale vector. It duplicated
        // the fq1_scale check above verbatim and its result was never consumed, so
        // it was removed. If an FQ between MatMul1 and Transpose3 (a fakeQuantize2)
        // was intended, the pattern above does not contain one.

        // In this pattern the scale multiply is applied right after MatMul0.
        bool is_mul_first = true;
        auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr();
        auto mha = std::make_shared<ov::intel_cpu::MHANode>(transpose0_in, transpose1_in, add_in1, transpose2_in, mul_scales, is_mul_first,
                                                            fq0_scale, std::vector<float>(), std::vector<float>(), fq1_scale,
                                                            transpose3_node->output(0).get_element_type());
        mha->set_friendly_name(m.get_match_root()->get_friendly_name());
        // Propagate runtime info from every replaced node onto the fused MHANode.
        ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(),
                                   pattern_to_output.at(transpose1).get_node_shared_ptr(),
                                   pattern_to_output.at(fakeQuantize0).get_node_shared_ptr(),
                                   pattern_to_output.at(matmul0).get_node_shared_ptr(),
                                   pattern_to_output.at(mul).get_node_shared_ptr(),
                                   pattern_to_output.at(add).get_node_shared_ptr(),
                                   pattern_to_output.at(softmax).get_node_shared_ptr(),
                                   pattern_to_output.at(transpose2).get_node_shared_ptr(),
                                   pattern_to_output.at(matmul1).get_node_shared_ptr(),
                                   pattern_to_output.at(fakeQuantize1).get_node_shared_ptr(),
                                   pattern_to_output.at(transpose3).get_node_shared_ptr(),
                                  },
                                  mha);

        // Give the plugin's transformation callback a chance to reject the fusion.
        if (transformation_callback(mha)) {
            return false;
        }

        ngraph::replace_node(m.get_match_root(), mha);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(transpose3, matcher_name);
    this->register_matcher(m, callback);
}

View File

@@ -0,0 +1,37 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ngraph/pass/graph_rewrite.hpp>
namespace ov {
namespace intel_cpu {
/// Fuses the FP32 MHA subgraph (Transpose/Multiply/MatMul/Add/Reshape/Softmax)
/// into a single MHANode; pattern details are in mha_fusion.cpp.
class MHAFusion: public ngraph::pass::MatcherPass {
public:
OPENVINO_RTTI("MHAFusion", "0");
MHAFusion();
};
/// MHA fusion variant without the explicit scale Multiply and without the
/// Reshape pair around the softmax; pattern details are in mha_fusion.cpp.
class MHAFusion2: public ngraph::pass::MatcherPass {
public:
OPENVINO_RTTI("MHAFusion2", "0");
MHAFusion2();
};
/// Quantized MHA fusion: matches FakeQuantize nodes after both MatMuls and
/// after the softmax, folding them into MHANode scales; see mha_fusion.cpp.
class MHAQuantFusion: public ngraph::pass::MatcherPass {
public:
OPENVINO_RTTI("MHAQuantFusion", "0");
MHAQuantFusion();
};
/// Quantized MHA fusion variant with FakeQuantize on the transposed key input
/// and after the second MatMul; see mha_fusion.cpp.
class MHAQuantFusion2: public ngraph::pass::MatcherPass {
public:
OPENVINO_RTTI("MHAQuantFusion2", "0");
MHAQuantFusion2();
};
} // namespace intel_cpu
} // namespace ov

View File

@@ -0,0 +1,98 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "mha.hpp"
#include "../itt.hpp"
#include <ngraph/opsets/opset3.hpp>
#include <matmul_shape_inference.hpp>
// MHANode constructor for the non-quantized MHA patterns: only the multiply
// scales and their position (before/after MatMul0) are fused; no FakeQuantize
// scales are stored. in0..in3 are the Q, K, mask and V inputs of the fused
// subgraph, output_type optionally overrides the inferred output precision.
ov::intel_cpu::MHANode::MHANode(const ngraph::Output<ngraph::Node> &in0,
                                const ngraph::Output<ngraph::Node> &in1,
                                const ngraph::Output<ngraph::Node> &in2,
                                const ngraph::Output<ngraph::Node> &in3,
                                const std::vector<float> &mul_scales_,
                                bool is_mul_first_,
                                const ngraph::element::Type output_type)
    : Op({in0, in1, in2, in3}), m_output_type(output_type) {
    is_mul_first = is_mul_first_;
    mul_scales = mul_scales_;
    validate_and_infer_types();
}
// MHANode constructor for the quantized MHA patterns: in addition to the
// multiply scales it stores the collapsed FakeQuantize scale vectors fused
// from the original subgraph (an empty vector means no FQ at that position).
ov::intel_cpu::MHANode::MHANode(const ngraph::Output<ngraph::Node> &in0,
                                const ngraph::Output<ngraph::Node> &in1,
                                const ngraph::Output<ngraph::Node> &in2,
                                const ngraph::Output<ngraph::Node> &in3,
                                const std::vector<float> &mul_scales_,
                                bool is_mul_first_,
                                const std::vector<float> &fq_scales0_,
                                const std::vector<float> &fq_scales1_,
                                const std::vector<float> &fq_scales2_,
                                const std::vector<float> &fq_scales3_,
                                const ngraph::element::Type output_type)
    : Op({in0, in1, in2, in3}), m_output_type(output_type) {
    is_mul_first = is_mul_first_;
    mul_scales = mul_scales_;
    fq_scales0 = fq_scales0_;
    fq_scales1 = fq_scales1_;
    fq_scales2 = fq_scales2_;
    fq_scales3 = fq_scales3_;
    validate_and_infer_types();
}
// Creates a copy of this node wired to the given inputs, preserving all fused
// scale vectors, the multiply position flag and the requested output precision.
std::shared_ptr<ngraph::Node> ov::intel_cpu::MHANode::clone_with_new_inputs(const ngraph::OutputVector& new_args) const {
INTERNAL_OP_SCOPE(MHANode_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<ov::intel_cpu::MHANode>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3),
mul_scales, is_mul_first, fq_scales0, fq_scales1, fq_scales2, fq_scales3, m_output_type);
}
// Deduces the output type and shape of the fused MHA subgraph by replaying its
// shape propagation: transpose(in0) x transpose(in1) -> MatMul0, then
// MatMul0-output x transpose(in3) -> MatMul1, then a final transpose.
// The elementwise/softmax stages between the matmuls do not change the shape,
// so they are not modeled here.
void ov::intel_cpu::MHANode::validate_and_infer_types() {
    INTERNAL_OP_SCOPE(MHANode_validate_and_infer_types);

    // Permutes `shape` by `order`.
    // NOTE(review): inputs are assumed to be static 4D shapes — get_shape()
    // on a dynamic PartialShape would throw; confirm the fusion passes only
    // match static-shape patterns.
    auto transpose = [](const ov::Shape& shape, const std::vector<size_t>& order) -> ov::Shape {
        // Build the ov::Shape directly (no intermediate std::vector) and use
        // size_t to avoid the signed/unsigned loop-counter comparison.
        ov::Shape new_shape(shape.size());
        for (size_t i = 0; i < shape.size(); i++) {
            new_shape[i] = shape[order[i]];
        }
        return new_shape;
    };

    // MatMul0: transpose(in0, [0,2,1,3]) x transpose(in1, [0,2,3,1])
    const auto matmul0_shape0 = transpose(get_input_partial_shape(0).get_shape(), {0, 2, 1, 3});
    const auto matmul0_shape1 = transpose(get_input_partial_shape(1).get_shape(), {0, 2, 3, 1});

    // Temporary MatMul op used only to drive the shared shape-inference helper.
    auto matmul0_in0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, matmul0_shape0);
    auto matmul0_in1 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, matmul0_shape1);
    auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(matmul0_in0, matmul0_in1);
    std::vector<ov::PartialShape> matmul0_input_shapes = {matmul0_shape0, matmul0_shape1};
    std::vector<ov::PartialShape> matmul0_output_shapes = {ov::PartialShape{}};
    shape_infer(matmul0.get(), matmul0_input_shapes, matmul0_output_shapes);

    // MatMul1: MatMul0 output x transpose(in3, [0,2,1,3])
    const auto matmul1_shape0 = matmul0_output_shapes[0];
    const auto matmul1_shape1 = transpose(get_input_partial_shape(3).get_shape(), {0, 2, 1, 3});
    auto matmul1_in0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, matmul1_shape0);
    auto matmul1_in1 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, matmul1_shape1);
    auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(matmul1_in0, matmul1_in1);
    std::vector<ov::PartialShape> matmul1_input_shapes = {matmul1_shape0, matmul1_shape1};
    std::vector<ov::PartialShape> matmul1_output_shapes = {ov::PartialShape{}};
    shape_infer(matmul1.get(), matmul1_input_shapes, matmul1_output_shapes);

    // Final transpose back to the original layout.
    const auto output_shape = transpose(matmul1_output_shapes[0].get_shape(), {0, 2, 1, 3});

    // An undefined/dynamic requested output type means "same as input 0".
    set_output_type(
        0,
        m_output_type == ngraph::element::undefined || m_output_type == ngraph::element::dynamic ? get_input_element_type(0) : m_output_type,
        output_shape);
}
// Serializes node attributes. Only the output element type is visited here;
// the scale vectors and the is_mul_first flag are not exposed to the visitor.
bool ov::intel_cpu::MHANode::visit_attributes(ngraph::AttributeVisitor &visitor) {
    INTERNAL_OP_SCOPE(MHANode_visit_attributes);
    visitor.on_attribute("out-type", m_output_type);
    return true;
}

View File

@@ -0,0 +1,78 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ngraph/op/op.hpp>
namespace ov {
namespace intel_cpu {
// Internal CPU-plugin operation representing a fused Multi-Head Attention
// subgraph (two matmuls with interleaved mul/add/softmax, plus surrounding
// transposes). Produced by the MHAFusion* passes; consumed by the MHA node.
class MHANode : public ngraph::op::Op {
public:
    OPENVINO_OP("MHA", "cpu_plugin_opset");

    MHANode() = default;

    // Non-quantized variant: inputs in0..in3 plus optional Mul scales.
    MHANode(const ngraph::Output<ngraph::Node> &in0,
            const ngraph::Output<ngraph::Node> &in1,
            const ngraph::Output<ngraph::Node> &in2,
            const ngraph::Output<ngraph::Node> &in3,
            const std::vector<float> &mul_scales,
            bool is_mul_first,
            const ngraph::element::Type output_type);

    // Quantized variant: additionally carries dequantization scales of the
    // four FakeQuantize ops captured from the matched subgraph.
    MHANode(const ngraph::Output<ngraph::Node> &in0,
            const ngraph::Output<ngraph::Node> &in1,
            const ngraph::Output<ngraph::Node> &in2,
            const ngraph::Output<ngraph::Node> &in3,
            const std::vector<float> &mul_scales,
            bool is_mul_first,
            const std::vector<float> &fq_scales0,
            const std::vector<float> &fq_scales1,
            const std::vector<float> &fq_scales2,
            const std::vector<float> &fq_scales3,
            const ngraph::element::Type output_type);

    void validate_and_infer_types() override;

    bool visit_attributes(ngraph::AttributeVisitor &visitor) override;

    std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector &new_args) const override;

    // Requested output element type; undefined/dynamic means "same as input 0".
    ngraph::element::Type get_output_type() const { return m_output_type; }

    const std::vector<float>& get_mul_scales() const {
        return mul_scales;
    }

    const std::vector<float>& get_fq_scales0() const {
        return fq_scales0;
    }
    const std::vector<float>& get_fq_scales1() const {
        return fq_scales1;
    }
    const std::vector<float>& get_fq_scales2() const {
        return fq_scales2;
    }
    const std::vector<float>& get_fq_scales3() const {
        return fq_scales3;
    }

    // Whether the Mul precedes the Add in the fused elementwise stage.
    bool get_is_mul_first() const {
        return is_mul_first;
    }

private:
    ngraph::element::Type m_output_type;
    // Scales of the optional Mul in the matched pattern (may be empty).
    std::vector<float> mul_scales;
    bool is_mul_first;
    // Dequantization scales of the matched FakeQuantize ops (empty for the
    // non-quantized variant).
    std::vector<float> fq_scales0;
    std::vector<float> fq_scales1;
    std::vector<float> fq_scales2;
    std::vector<float> fq_scales3;
};
} // namespace intel_cpu
} // namespace ov

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,242 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <node.h>
#include <memory>
#include <string>
#include <vector>
#include <cpu/x64/brgemm/brgemm.hpp>
#include <cpu/x64/matmul/brgemm_matmul_copy_utils.hpp>
#include <cpu/x64/matmul/brgemm_matmul_utils.hpp>
#include <cpu/x64/amx_tile_configure.hpp>
namespace ov {
namespace intel_cpu {
namespace node {
// Compile-time configuration of the fused mul+add+softmax JIT kernel.
struct jit_mul_add_softmax_compile_params {
    InferenceEngine::Precision src_prc;  // source precision
    InferenceEngine::Precision dst_prc;  // destination precision
    size_t work_amount;                  // elements processed per kernel call
    bool with_mul_scales;                // apply the mul-scales stage
    bool is_mul_first;                   // mul before add (vs add before mul)
    bool with_scales0;                   // apply scales0 (per-tensor when broadcast)
    bool broadcast_scales0;
    bool with_scales1;                   // apply scales1 (per-tensor when broadcast)
    bool broadcast_scales1;
};
// Per-call runtime arguments of the mul+add+softmax JIT kernel.
struct jit_mul_add_softmax_call_args {
    const void *p_in0;      // main input
    const void *p_mul_in1;  // second mul operand
    const void *p_add_in1;  // second add operand
    void *p_out;            // output
    void *p_buffer;         // scratch buffer
    const void *p_scales0;  // scales applied when jcp.with_scales0
    const void *p_scales1;  // scales applied when jcp.with_scales1
};
// Abstract base for the mul+add+softmax JIT kernel. create_ker() must be
// called to generate the code before the first operator() invocation
// (otherwise ker_ stays null and the assert fires).
struct jit_uni_mul_add_softmax_kernel {
    void (*ker_)(const jit_mul_add_softmax_call_args*);

    void operator()(const jit_mul_add_softmax_call_args* call_args) {
        assert(ker_);
        ker_(call_args);
    }

    explicit jit_uni_mul_add_softmax_kernel(const jit_mul_add_softmax_compile_params& jcp) : ker_(nullptr), jcp_(jcp) {}
    virtual ~jit_uni_mul_add_softmax_kernel() {}

    // Generates the machine code and sets ker_.
    virtual void create_ker() = 0;

    jit_mul_add_softmax_compile_params jcp_;
};
// Compile-time configuration of the convert+reorder JIT kernel
// (precision conversion with strided copy, optionally scaled).
struct jit_convert_reorder_compile_params {
    InferenceEngine::Precision src_prc;  // source precision
    InferenceEngine::Precision dst_prc;  // destination precision
    size_t inner_work_amount;            // elements per inner iteration
    bool with_scales;                    // apply dequantization scales
    bool broadcast_scales;               // single scale broadcast over elements
    size_t src_stride;
    size_t dst_stride;
};
// Per-call runtime arguments of the convert+reorder JIT kernel.
struct jit_convert_reorder_call_args {
    const void *p_in;
    void *p_out;
    const void *p_scales;       // used when jcp.with_scales
    size_t outter_work_amount;  // outer-loop iteration count (sic: "outter")
};
// Abstract base for the convert+reorder JIT kernel. create_ker() must be
// called before the first operator() invocation (ker_ starts null).
struct jit_uni_convert_reorder_kernel {
    void (*ker_)(const jit_convert_reorder_call_args*);

    void operator()(const jit_convert_reorder_call_args* call_args) {
        assert(ker_);
        ker_(call_args);
    }

    explicit jit_uni_convert_reorder_kernel(const jit_convert_reorder_compile_params& jcp) : ker_(nullptr), jcp_(jcp) {}
    virtual ~jit_uni_convert_reorder_kernel() {}

    // Generates the machine code and sets ker_.
    virtual void create_ker() = 0;

    jit_convert_reorder_compile_params jcp_;
};
// Compile-time configuration of the convert+transpose JIT kernel
// (precision conversion combined with a strided 2-level copy).
struct jit_convert_transpose_compile_params {
    InferenceEngine::Precision src_prc;  // source precision
    InferenceEngine::Precision dst_prc;  // destination precision
    size_t inner_work_amount;            // elements per inner iteration
    size_t outter_work_amount;           // outer-loop iteration count (sic: "outter")
    bool with_scales;                    // apply dequantization scales
    bool broadcast_scales;               // single scale broadcast over elements
    size_t inner_src_stride;
    size_t outter_src_stride;
    size_t outter_dst_stride;
};
// Per-call runtime arguments of the convert+transpose JIT kernel.
struct jit_convert_transpose_call_args {
    const void *p_in;
    void *p_out;
    const void *p_scales;  // used when jcp.with_scales
};
// Abstract base for the convert+transpose JIT kernel. create_ker() must be
// called before the first operator() invocation (ker_ starts null).
struct jit_uni_convert_transpose_kernel {
    void (*ker_)(const jit_convert_transpose_call_args*);

    void operator()(const jit_convert_transpose_call_args* call_args) {
        assert(ker_);
        ker_(call_args);
    }

    explicit jit_uni_convert_transpose_kernel(const jit_convert_transpose_compile_params& jcp) : ker_(nullptr), jcp_(jcp) {}
    virtual ~jit_uni_convert_transpose_kernel() {}

    // Generates the machine code and sets ker_.
    virtual void create_ker() = 0;

    jit_convert_transpose_compile_params jcp_;
};
#define MHA_BRGEMM_KERNELS_NUM 8
// CPU-plugin node executing the fused MHA subgraph: two brgemm-based matmuls
// with JIT kernels for the intermediate mul/add/softmax stage and for
// convert/reorder/transpose of inputs and outputs.
class MHA : public Node {
public:
    MHA(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache);

    void getSupportedDescriptors() override {};
    void initSupportedPrimitiveDescriptors() override;
    void execute(dnnl::stream strm) override;
    bool created() const override;
    // Returns true when the MHANode op can be executed by this implementation;
    // otherwise fills errorMessage with the rejection reason.
    static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;

protected:
    void executeDynamicImpl(dnnl::stream strm) override;
    void prepareParams() override;

private:
    // Per-kernel brgemm configuration: GEMM sizes, leading dimensions,
    // input data types, AMX tile palette and compensation/beta settings.
    struct brgemmCtx {
        size_t M, N, K, LDA, LDB, LDC;
        dnnl_data_type_t dt_in0, dt_in1;
        char palette[64];
        bool is_with_amx;
        bool is_with_comp;
        float beta;
    };

    // Main execution path, templated on the element type of the second
    // matmul inputs (in1).
    template <typename in1_type>
    void mhaImpl();

    void init_brgemm(brgemmCtx& ctx, std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& brgKernel, bool use_amx);
    void init_brgemm_copy_a(std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_a_t>& brgCopyKernel,
        size_t K, size_t K_blk, size_t K_tail, size_t LDA, dnnl_data_type_t dt_in0);
    void init_brgemm_copy_b(std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t>& brgCopyKernel,
        size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K, bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1);

    void callBrgemm(brgemmCtx& ctx, std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& brgKernel,
                    const void* pin0, const void* pin1, void* pout, void* wsp);

    // Linearizes the (mIdx, kIdx, nIdx) kernel selector into an index in the
    // MHA_BRGEMM_KERNELS_NUM(=8) kernel arrays; each index is 0 or 1
    // (main block vs tail block per dimension).
    size_t getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) {
        return mIdx * 4 + kIdx * 2 + nIdx;
    }

    std::vector<InferenceEngine::Precision> inputPrecisions;
    InferenceEngine::Precision accPrecision0;  // accumulation precision of matmul0
    InferenceEngine::Precision accPrecision1;  // accumulation precision of matmul1

    // Shapes and strides of the (logically) transposed inputs and the output.
    VectorDims dimsTranspose0In0;
    VectorDims dimsTranspose1In0;
    VectorDims dimsMulIn1;
    VectorDims dimsAddIn1;
    VectorDims dimsTranspose2In0;
    VectorDims dimsOut;

    VectorDims strTranspose0In0;
    VectorDims strTranspose1In0;
    VectorDims strMulIn1;
    VectorDims strAddIn1;
    VectorDims strTranspose2In0;
    VectorDims strOut;

    // Shapes of the two matmuls after transposition.
    VectorDims dimsMatMul0In0;
    VectorDims dimsMatMul0In1;
    VectorDims dimsMatMul0Out;
    VectorDims dimsMatMul1In1;
    VectorDims dimsMatMul1Out;

    size_t batch0, batch1;

    // Blocking parameters: M shared by both matmuls, K/N per matmul,
    // each with main block size and tail.
    size_t M, M_blk, M_tail;
    size_t K0, K0_blk, K0_tail, N0, N0_blk, N0_tail;
    size_t K1, K1_blk, K1_tail, N1, N1_blk, N1_tail;

    // Sizes of the per-thread intermediate buffers below.
    size_t bufferMatMul0In0Size;
    size_t bufferMatMul0In1Size;
    size_t bufferMatMul0OutSize;
    size_t bufferMatMul1In1Size;
    size_t bufferMatMul1OutSize;
    size_t bufferCompensation0Size;
    size_t bufferCompensation1Size;
    size_t wsp_size_per_thread = 4 * 1024;  // brgemm workspace bytes per thread

    std::vector<uint8_t> bufferMatMul0In0;
    std::vector<uint8_t> bufferMatMul0In1;
    std::vector<uint8_t> bufferMatMul0Out;
    std::vector<uint8_t> bufferMatMul1In1;
    std::vector<uint8_t> bufferMatMul1Out;
    std::vector<int32_t> bufferCompensation0;  // int8 zero-point compensation for matmul0
    std::vector<int32_t> bufferCompensation1;  // int8 zero-point compensation for matmul1
    std::vector<size_t> wsp;                   // brgemm workspace (all threads)

    // Attributes copied from the fused MHANode op.
    bool isMulFirst;
    std::vector<float> mulScales;
    std::vector<float> fqScales0;
    std::vector<float> fqScales1;
    std::vector<float> fqScales2;
    std::vector<float> fqScales3;

    // Matmul0 kernels: one brgemm per (M/K/N main-vs-tail) combination,
    // plus copy kernels for A/B repacking.
    size_t brg0VnniFactor;
    brgemmCtx brgCtxs0[MHA_BRGEMM_KERNELS_NUM];
    std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t> brgKernels0[MHA_BRGEMM_KERNELS_NUM];
    std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_a_t> brgCopyAKernel0;
    std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t> brgCopyBKernel0;

    // Matmul1 kernels (no A-copy kernel: its A operand is produced in place).
    size_t brg1VnniFactor;
    brgemmCtx brgCtxs1[MHA_BRGEMM_KERNELS_NUM];
    std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t> brgKernels1[MHA_BRGEMM_KERNELS_NUM];
    std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t> brgCopyBKernel1;

    // JIT kernels for the in-between elementwise/softmax and the
    // convert/reorder/transpose stages.
    std::unique_ptr<jit_uni_mul_add_softmax_kernel> mulAddSoftmaxKernel;
    std::unique_ptr<jit_uni_convert_reorder_kernel> convertReorderKernel;
    std::unique_ptr<jit_uni_convert_transpose_kernel> convertTransposeKernel;
};
} // namespace node
} // namespace intel_cpu
} // namespace ov

View File

@@ -88,6 +88,7 @@
#include "nodes/priorbox.h"
#include "nodes/priorbox_clustered.h"
#include "nodes/eye.h"
#include "nodes/mha.h"
namespace ov {
namespace intel_cpu {
@@ -188,6 +189,7 @@ Node::NodesFactory::NodesFactory()
INTEL_CPU_NODE(PriorBox, Type::PriorBox);
INTEL_CPU_NODE(PriorBoxClustered, Type::PriorBoxClustered);
INTEL_CPU_NODE(Eye, Type::Eye);
INTEL_CPU_NODE(MHA, Type::MHA);
}
#undef INTEL_CPU_NODE

View File

@@ -86,6 +86,7 @@
#include <transformations/op_conversions/convert_roi_align_v9_to_v3.hpp>
#include <transformations/op_conversions/convert_roi_align_v3_to_v9.hpp>
#include <transformations/op_conversions/softsign_decomposition.hpp>
#include "ngraph_transformations/mha_fusion.hpp"
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset2.hpp>
@@ -117,6 +118,7 @@
#include "nodes/mvn.h"
#include "nodes/fake_quantize.h"
#include "nodes/normalize.h"
#include "nodes/mha.h"
#include "ngraph_transformations/convert_to_cpu_specific_opset.hpp"
#include "ngraph_transformations/move_eltwise_up_data_movement.hpp"
#include "transformations/smart_reshape/smart_reshape.hpp"
@@ -248,7 +250,7 @@ Engine::~Engine() {
executorManager()->clear("CPUCallbackExecutor");
}
static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function> nGraphFunc, const bool _enableLPT,
static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function> nGraphFunc, const bool _enableLPT, const bool _enableBF16,
const bool _enableSnippets, const bool isLegacyApi) {
ngraph::pass::Manager manager;
manager.set_per_pass_validation(false);
@@ -598,6 +600,34 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
});
postLPTPassManager.register_pass<ngraph::pass::ConstantFolding>();
auto isMHASupported = [_enableBF16](const std::shared_ptr<const ov::Node>& n) -> bool {
std::string errorMessage;
if (!node::MHA::isSupportedOperation(n, errorMessage))
return true;
        // The implementation calls the AMX BF16 brgemm only for tensors with K and N aligned on 2; otherwise it falls back to the vector implementation.
        // The vector madd BF16 instruction on SPR has reduced performance at the HW level, which results in overall perf degradation.
size_t bf16Factor = 2;
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16_amx_bf16) &&
(n->get_input_element_type(0) == element::bf16 || (n->get_input_element_type(0) == element::f32 && _enableBF16)) &&
(n->get_input_shape(0)[3] % bf16Factor != 0 || n->get_input_shape(1)[1] % bf16Factor != 0 || n->get_input_shape(3)[3] % bf16Factor != 0)) {
return true;
}
return false;
};
    // Snippets may break MHA patterns, so the fusion has to be performed before them
postLPTPassManager.register_pass<MHAFusion>();
postLPTPassManager.get_pass_config()->set_callback<MHAFusion>(isMHASupported);
postLPTPassManager.register_pass<MHAFusion2>();
postLPTPassManager.get_pass_config()->set_callback<MHAFusion2>(isMHASupported);
postLPTPassManager.register_pass<MHAQuantFusion>();
postLPTPassManager.get_pass_config()->set_callback<MHAQuantFusion>(isMHASupported);
postLPTPassManager.register_pass<MHAQuantFusion2>();
postLPTPassManager.get_pass_config()->set_callback<MHAQuantFusion2>(isMHASupported);
postLPTPassManager.run_passes(nGraphFunc);
if (!useLpt && _enableSnippets && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) {
@@ -629,9 +659,9 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
}
}
static void Transformation(CNNNetwork& clonedNetwork, const bool _enableLPT, const bool _enableSnippets, const bool isLegacyApi) {
static void Transformation(CNNNetwork& clonedNetwork, const bool _enableLPT, const bool _enableBF16, const bool _enableSnippets, const bool isLegacyApi) {
auto nGraphFunc = clonedNetwork.getFunction();
TransformationUpToCPUSpecificOpSet(nGraphFunc, _enableLPT, _enableSnippets, isLegacyApi);
TransformationUpToCPUSpecificOpSet(nGraphFunc, _enableLPT, _enableBF16, _enableSnippets, isLegacyApi);
ConvertToCPUSpecificOpset(nGraphFunc);
}
@@ -772,7 +802,7 @@ Engine::LoadExeNetworkImpl(const InferenceEngine::CNNNetwork &network, const std
|| engConfig.enableDynamicBatch;
const bool enableSnippets = !(enableModelCache || enableDynamicBatch || enableBF16);
auto nGraphFunc = clonedNetwork.getFunction();
TransformationUpToCPUSpecificOpSet(nGraphFunc, enableLPT, enableSnippets, isLegacyAPI());
TransformationUpToCPUSpecificOpSet(nGraphFunc, enableLPT, enableBF16, enableSnippets, isLegacyAPI());
// need to check that all outputs have static shapes
// checking that all inputs have static shapes is performed in the common part
@@ -1020,7 +1050,7 @@ QueryNetworkResult Engine::QueryNetwork(const CNNNetwork& network, const std::ma
|| Config::LPTransformsMode::On == engConfig.lpTransformsMode /* or already enabled */;
const bool enableSnippets = !(conf.cache_dir.empty() || conf.enableDynamicBatch || (conf.enforceBF16
&& dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core)));
Transformation(clonedNetwork, enableLPT, enableSnippets, isLegacyAPI());
Transformation(clonedNetwork, enableLPT, conf.enforceBF16, enableSnippets, isLegacyAPI());
auto ops = clonnedFunction->get_ordered_ops();
//Mark removed nodes as supported