[CPU] Support MHA optimization (#12936)

This commit is contained in:
Gorokhov Dmitriy 2022-09-09 10:03:19 +04:00 committed by GitHub
parent 0dd1f6e1b0
commit e7fe00f5f2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 3153 additions and 5 deletions

View File

@ -97,6 +97,13 @@ INFERENCE_ENGINE_API_CPP(bool) with_cpu_x86_avx512f();
*/ */
INFERENCE_ENGINE_API_CPP(bool) with_cpu_x86_avx512_core(); INFERENCE_ENGINE_API_CPP(bool) with_cpu_x86_avx512_core();
/**
* @brief Checks whether CPU supports AVX 512 VNNI capability
* @ingroup ie_dev_api_system_conf
* @return `True` is AVX512F, AVX512BW, AVX512DQ, AVX512_VNNI instructions are available, `false` otherwise
*/
INFERENCE_ENGINE_API_CPP(bool) with_cpu_x86_avx512_core_vnni();
/** /**
* @brief Checks whether CPU supports BFloat16 capability * @brief Checks whether CPU supports BFloat16 capability
* @ingroup ie_dev_api_system_conf * @ingroup ie_dev_api_system_conf

View File

@ -41,6 +41,10 @@ bool with_cpu_x86_avx512_core() {
return get_cpu_info().has(Xbyak::util::Cpu::tAVX512F | Xbyak::util::Cpu::tAVX512DQ | Xbyak::util::Cpu::tAVX512BW); return get_cpu_info().has(Xbyak::util::Cpu::tAVX512F | Xbyak::util::Cpu::tAVX512DQ | Xbyak::util::Cpu::tAVX512BW);
} }
bool with_cpu_x86_avx512_core_vnni() {
return with_cpu_x86_avx512_core() && get_cpu_info().has(Xbyak::util::Cpu::tAVX512_VNNI);
}
bool with_cpu_x86_bfloat16() { bool with_cpu_x86_bfloat16() {
return get_cpu_info().has(Xbyak::util::Cpu::tAVX512_BF16); return get_cpu_info().has(Xbyak::util::Cpu::tAVX512_BF16);
} }

View File

@ -197,6 +197,7 @@ const InferenceEngine::details::caseless_unordered_map<std::string, Type> type_t
{ "Subgraph", Type::Subgraph}, { "Subgraph", Type::Subgraph},
{ "PriorBox", Type::PriorBox}, { "PriorBox", Type::PriorBox},
{ "PriorBoxClustered", Type::PriorBoxClustered}, { "PriorBoxClustered", Type::PriorBoxClustered},
{ "MHA", Type::MHA},
}; };
Type TypeFromName(const std::string& type) { Type TypeFromName(const std::string& type) {
@ -388,6 +389,8 @@ std::string NameFromType(const Type type) {
return "Reference"; return "Reference";
case Type::Subgraph: case Type::Subgraph:
return "Subgraph"; return "Subgraph";
case Type::MHA:
return "MHA";
default: default:
return "Unknown"; return "Unknown";
} }

View File

@ -108,6 +108,7 @@ enum class Type {
Subgraph, Subgraph,
PriorBox, PriorBox,
PriorBoxClustered, PriorBoxClustered,
MHA
}; };
enum class Algorithm { enum class Algorithm {

View File

@ -7,6 +7,7 @@
#include "ngraph_transformations/op/leaky_relu.hpp" #include "ngraph_transformations/op/leaky_relu.hpp"
#include "ngraph_transformations/op/power_static.hpp" #include "ngraph_transformations/op/power_static.hpp"
#include "ngraph_transformations/op/swish_cpu.hpp" #include "ngraph_transformations/op/swish_cpu.hpp"
#include "ngraph_transformations/op/mha.hpp"
#include "snippets_transformations/op/load_convert.hpp" #include "snippets_transformations/op/load_convert.hpp"
#include "snippets_transformations/op/store_convert.hpp" #include "snippets_transformations/op/store_convert.hpp"
@ -44,6 +45,7 @@ std::map<std::string, ngraph::OpSet> Extension::getOpSets() {
NGRAPH_OP(LeakyReluNode, ov::intel_cpu) NGRAPH_OP(LeakyReluNode, ov::intel_cpu)
NGRAPH_OP(PowerStaticNode, ov::intel_cpu) NGRAPH_OP(PowerStaticNode, ov::intel_cpu)
NGRAPH_OP(SwishNode, ov::intel_cpu) NGRAPH_OP(SwishNode, ov::intel_cpu)
NGRAPH_OP(MHANode, ov::intel_cpu)
NGRAPH_OP(LoadConvertSaturation, ov::intel_cpu) NGRAPH_OP(LoadConvertSaturation, ov::intel_cpu)
NGRAPH_OP(LoadConvertTruncation, ov::intel_cpu) NGRAPH_OP(LoadConvertTruncation, ov::intel_cpu)
NGRAPH_OP(StoreConvertSaturation, ov::intel_cpu) NGRAPH_OP(StoreConvertSaturation, ov::intel_cpu)

View File

@ -0,0 +1,644 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "mha_fusion.hpp"
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include "op/mha.hpp"
#include "itt.hpp"
// TODO: draw pattern
ov::intel_cpu::MHAFloatFusion::MHAFloatFusion() {
MATCHER_SCOPE(MHAFloatFusion);
auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in2 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in4 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in5 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in6 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in7 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in9 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in10 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto transpose0 = std::make_shared<ngraph::opset3::Transpose>(in0, in4);
auto transpose1 = std::make_shared<ngraph::opset3::Transpose>(in1, in5);
auto mul = std::make_shared<ngraph::opset3::Multiply>(transpose1, in2);
auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, mul);
auto add = std::make_shared<ngraph::opset4::Add>(matmul0, in3);
auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(add, in6, true);
auto softmax = std::make_shared<ngraph::opset1::Softmax>(reshape0);
auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softmax, in7, true);
auto transpose2 = std::make_shared<ngraph::opset3::Transpose>(in8, in9);
auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(reshape1, transpose2);
auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(matmul1, in10);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto transpose0_in = pattern_to_output.at(in0);
auto transpose1_in = pattern_to_output.at(in1);
auto mul_in1 = pattern_to_output.at(in2);
auto add_in1 = pattern_to_output.at(in3);
auto transpose2_in = pattern_to_output.at(in8);
if (transpose0_in.get_shape() != transpose1_in.get_shape() || transpose0_in.get_shape() != transpose2_in.get_shape()) {
return false;
}
if (transpose0_in.get_shape().size() != 4) {
return false;
}
auto expected_add_shape = Shape({transpose0_in.get_shape()[0], 1, 1, transpose0_in.get_shape()[1]});
if (add_in1.get_shape() != expected_add_shape) {
return false;
}
if (!valid_transpose_order(pattern_to_output.at(in4).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
if (!valid_transpose_order(pattern_to_output.at(in5).get_node_shared_ptr(), {0, 2, 3, 1})) return false;
if (!valid_transpose_order(pattern_to_output.at(in9).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
if (!valid_transpose_order(pattern_to_output.at(in10).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
std::vector<float> mul_scales;
if (auto mul_node = ngraph::as_type_ptr<ngraph::opset3::Multiply>(pattern_to_output.at(mul).get_node_shared_ptr())) {
mul_scales = ngraph::as_type_ptr<ngraph::opset4::Constant>(mul_node->get_input_node_shared_ptr(1))->cast_vector<float>();
auto expected_shape = ngraph::Shape({1, transpose0_in.get_shape()[2], 1, 1});
if (mul_scales.size() != 1 && mul_node->get_input_shape(1) != expected_shape) {
return false;
}
} else {
return false;
}
auto matmul0_node = ngraph::as_type_ptr<ngraph::opset3::MatMul>(pattern_to_output.at(matmul0).get_node_shared_ptr());
if (!matmul0_node)
return false;
if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b())
return false;
auto reshape0_node = ngraph::as_type_ptr<ngraph::opset1::Reshape>(pattern_to_output.at(reshape0).get_node_shared_ptr());
if (!reshape0_node)
return false;
if (auto reshape_pattern = ngraph::as_type_ptr<ngraph::opset4::Constant>(pattern_to_output.at(in6).get_node_shared_ptr())) {
if (reshape0_node->get_input_shape(0).size() != 4) {
return false;
}
std::vector<int64_t> reshapeConstData = {static_cast<int64_t>(reshape0_node->get_input_shape(0)[0] *
reshape0_node->get_input_shape(0)[1] *
reshape0_node->get_input_shape(0)[2]),
-1};
if (reshape_pattern->cast_vector<int64_t>() != reshapeConstData) {
return false;
}
} else {
return false;
}
if (auto reshape1_node = ngraph::as_type_ptr<ngraph::opset1::Reshape>(pattern_to_output.at(reshape1).get_node_shared_ptr())) {
if (reshape0_node->get_input_shape(0) != reshape1_node->get_output_shape(0)) {
return false;
}
} else {
return false;
}
auto softmax_node = ngraph::as_type_ptr<ngraph::opset1::Softmax>(pattern_to_output.at(softmax).get_node_shared_ptr());
if (!softmax_node)
return false;
if (softmax_node->get_axis() != 1)
return false;
auto matmul1_node = ngraph::as_type_ptr<ngraph::opset3::MatMul>(pattern_to_output.at(matmul1).get_node_shared_ptr());
if (!matmul1_node)
return false;
if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b())
return false;
bool is_mul_first = true;
auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr();
auto mha = std::make_shared<ov::intel_cpu::MHANode>(transpose0_in, transpose1_in, add_in1, transpose2_in, mul_scales, is_mul_first,
transpose3_node->get_output_element_type(0));
mha->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(),
pattern_to_output.at(transpose1).get_node_shared_ptr(),
pattern_to_output.at(mul).get_node_shared_ptr(),
pattern_to_output.at(matmul0).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(reshape0).get_node_shared_ptr(),
pattern_to_output.at(softmax).get_node_shared_ptr(),
pattern_to_output.at(reshape1).get_node_shared_ptr(),
pattern_to_output.at(transpose2).get_node_shared_ptr(),
pattern_to_output.at(matmul1).get_node_shared_ptr(),
pattern_to_output.at(transpose3).get_node_shared_ptr(),
},
mha);
if (transformation_callback(mha)) {
return false;
}
ngraph::replace_node(m.get_match_root(), mha);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose3, matcher_name);
this->register_matcher(m, callback);
}
ov::intel_cpu::MHAFloatFusion2::MHAFloatFusion2() {
MATCHER_SCOPE(MHAFloatFusion2);
auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in4 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in5 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in6 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in7 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in9 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in10 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto transpose0 = std::make_shared<ngraph::opset3::Transpose>(in0, in4);
auto transpose1 = std::make_shared<ngraph::opset3::Transpose>(in1, in5);
auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1);
auto add = std::make_shared<ngraph::opset4::Add>(matmul0, in3);
auto softmax = std::make_shared<ngraph::opset1::Softmax>(add);
auto transpose2 = std::make_shared<ngraph::opset3::Transpose>(in8, in9);
auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, transpose2);
auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(matmul1, in10);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto transpose0_in = pattern_to_output.at(in0);
auto transpose1_in = pattern_to_output.at(in1);
auto add_in1 = pattern_to_output.at(in3);
auto transpose2_in = pattern_to_output.at(in8);
if (transpose0_in.get_shape() != transpose1_in.get_shape() || transpose0_in.get_shape() != transpose2_in.get_shape()) {
return false;
}
if (transpose0_in.get_shape().size() != 4) {
return false;
}
auto expected_add_shape = Shape({transpose0_in.get_shape()[0], 1, 1, transpose0_in.get_shape()[1]});
if (add_in1.get_shape() != expected_add_shape) {
return false;
}
if (!valid_transpose_order(pattern_to_output.at(in4).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
if (!valid_transpose_order(pattern_to_output.at(in5).get_node_shared_ptr(), {0, 2, 3, 1})) return false;
if (!valid_transpose_order(pattern_to_output.at(in9).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
if (!valid_transpose_order(pattern_to_output.at(in10).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
auto matmul0_node = ngraph::as_type_ptr<ngraph::opset3::MatMul>(pattern_to_output.at(matmul0).get_node_shared_ptr());
if (!matmul0_node)
return false;
if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b())
return false;
auto softmax_node = ngraph::as_type_ptr<ngraph::opset1::Softmax>(pattern_to_output.at(softmax).get_node_shared_ptr());
if (!softmax_node)
return false;
if (softmax_node->get_axis() != 3)
return false;
auto matmul1_node = ngraph::as_type_ptr<ngraph::opset3::MatMul>(pattern_to_output.at(matmul1).get_node_shared_ptr());
if (!matmul1_node)
return false;
if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b())
return false;
auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr();
auto mha = std::make_shared<ov::intel_cpu::MHANode>(transpose0_in, transpose1_in, add_in1, transpose2_in, std::vector<float>(), false,
transpose3_node->get_output_element_type(0));
mha->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(),
pattern_to_output.at(transpose1).get_node_shared_ptr(),
pattern_to_output.at(matmul0).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(softmax).get_node_shared_ptr(),
pattern_to_output.at(transpose2).get_node_shared_ptr(),
pattern_to_output.at(matmul1).get_node_shared_ptr(),
pattern_to_output.at(transpose3).get_node_shared_ptr(),
},
mha);
if (transformation_callback(mha)) {
return false;
}
ngraph::replace_node(m.get_match_root(), mha);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose3, matcher_name);
this->register_matcher(m, callback);
}
static std::vector<float> simplifyToScale(const std::shared_ptr<ngraph::opset1::FakeQuantize>& fq_node) {
auto levels = fq_node->get_levels();
auto input_low = ngraph::as_type_ptr<ngraph::opset4::Constant>(fq_node->get_input_node_shared_ptr(1))->cast_vector<float>();
auto input_high = ngraph::as_type_ptr<ngraph::opset4::Constant>(fq_node->get_input_node_shared_ptr(2))->cast_vector<float>();
auto output_low = ngraph::as_type_ptr<ngraph::opset4::Constant>(fq_node->get_input_node_shared_ptr(3))->cast_vector<float>();
auto output_high = ngraph::as_type_ptr<ngraph::opset4::Constant>(fq_node->get_input_node_shared_ptr(4))->cast_vector<float>();
std::vector<float> cl, ch, isc, ish, osc, osh;
for (int i = 0; i < input_low.size(); i++) {
cl.push_back(input_low[i]);
}
for (int i = 0; i < input_high.size(); i++) {
ch.push_back(input_high[i]);
}
for (int i = 0; i < std::max(input_low.size(), input_high.size()); i++) {
float il = input_low[input_low.size() == 1 ? 0 : i];
float ih = input_high[input_high.size() == 1 ? 0 : i];
isc.push_back((levels - 1) / (ih - il));
ish.push_back(-il * (levels - 1) / (ih - il));
}
for (int i = 0; i < std::max(output_low.size(), output_high.size()); i++) {
float ol = output_low[output_low.size() == 1 ? 0 : i];
float oh = output_high[output_high.size() == 1 ? 0 : i];
osc.push_back((oh - ol) / (levels - 1));
osh.push_back(ol);
}
std::vector<float> outScale;
if (fq_node->get_output_element_type(0) == ngraph::element::u8 &&
std::all_of(cl.cbegin(), cl.cend(), [](float val) { return val == 0.0f; }) &&
std::all_of(ish.cbegin(), ish.cend(), [](float val) { return val == 0.0f; }) &&
std::all_of(osc.cbegin(), osc.cend(), [](float val) { return val == 1.0f; }) &&
std::all_of(osh.cbegin(), osh.cend(), [](float val) { return val == 0.0f; })) {
outScale = isc;
}
if (fq_node->get_output_element_type(0) == ngraph::element::i8 &&
std::all_of(ish.cbegin(), ish.cend(), [](float val) { return std::abs(val - 128.f) < 0.0001f; }) &&
std::all_of(osc.cbegin(), osc.cend(), [](float val) { return val == 1.f; }) &&
std::all_of(osh.cbegin(), osh.cend(), [](float val) { return std::abs(val + 128.f) < 0.0001f; })) {
bool isCropAligned = true;
for (int i = 0; i < std::max(cl.size(), isc.size()); i++) {
if (std::abs(cl[cl.size() == 1 ? 0 : i] * isc[isc.size() == 1 ? 0 : i] + 128.f) > 0.0001f) {
isCropAligned = false;
}
}
for (int i = 0; i < std::max(ch.size(), isc.size()); i++) {
if (std::abs(ch[ch.size() == 1 ? 0 : i] * isc[isc.size() == 1 ? 0 : i] - 127.f) > 0.0001f) {
isCropAligned = false;
}
}
if (isCropAligned) {
outScale = isc;
}
}
return outScale;
}
// TODO: draw pattern
ov::intel_cpu::MHAQuantFusion::MHAQuantFusion() {
MATCHER_SCOPE(MHAQuantFusion);
auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in2 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in4 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in5 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in6 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in7 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in9 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in10 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto transpose0 = std::make_shared<ngraph::opset3::Transpose>(in0, in4);
auto transpose1 = std::make_shared<ngraph::opset3::Transpose>(in1, in5);
auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1);
auto fakeQuantize0 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({matmul0,
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
auto add = std::make_shared<ngraph::opset4::Add>(fakeQuantize0, in3);
auto mul = std::make_shared<ngraph::opset3::Multiply>(add, in2);
auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(mul, in6, true);
auto softmax = std::make_shared<ngraph::opset1::Softmax>(reshape0);
auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softmax, in7, true);
auto fakeQuantize1 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({reshape1,
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
auto transpose2 = std::make_shared<ngraph::opset3::Transpose>(in8, in9);
auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(fakeQuantize1, transpose2);
auto fakeQuantize2 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({matmul1,
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(fakeQuantize2, in10);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto transpose0_in = pattern_to_output.at(in0);
auto transpose1_in = pattern_to_output.at(in1);
auto add_in1 = pattern_to_output.at(in3);
auto transpose2_in = pattern_to_output.at(in8);
if (transpose0_in.get_shape() != transpose1_in.get_shape() || transpose0_in.get_shape() != transpose2_in.get_shape()) {
return false;
}
if (transpose0_in.get_shape().size() != 4) {
return false;
}
auto expected_add_shape = Shape({transpose0_in.get_shape()[0], 1, 1, transpose0_in.get_shape()[1]});
if (add_in1.get_shape() != expected_add_shape) {
return false;
}
std::vector<float> mul_scales;
if (auto mul_node = ngraph::as_type_ptr<ngraph::opset3::Multiply>(pattern_to_output.at(mul).get_node_shared_ptr())) {
mul_scales = ngraph::as_type_ptr<ngraph::opset4::Constant>(mul_node->get_input_node_shared_ptr(1))->cast_vector<float>();
auto expected_shape = ngraph::Shape({1, transpose0_in.get_shape()[2], 1, 1});
if (mul_scales.size() != 1 && mul_node->get_input_shape(1) != expected_shape) {
return false;
}
} else {
return false;
}
if (!valid_transpose_order(pattern_to_output.at(in4).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
if (!valid_transpose_order(pattern_to_output.at(in5).get_node_shared_ptr(), {0, 2, 3, 1})) return false;
if (!valid_transpose_order(pattern_to_output.at(in9).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
if (!valid_transpose_order(pattern_to_output.at(in10).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
auto matmul0_node = ngraph::as_type_ptr<ngraph::opset3::MatMul>(pattern_to_output.at(matmul0).get_node_shared_ptr());
if (!matmul0_node)
return false;
if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b())
return false;
std::vector<float> fq0_scale;
auto fq0_node = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize0).get_node_shared_ptr());
if (fq0_node) {
fq0_scale = simplifyToScale(fq0_node);
if (!fq0_scale.size())
return false;
}
auto reshape0_node = ngraph::as_type_ptr<ngraph::opset1::Reshape>(pattern_to_output.at(reshape0).get_node_shared_ptr());
if (!reshape0_node)
return false;
if (auto reshape_pattern = ngraph::as_type_ptr<ngraph::opset4::Constant>(pattern_to_output.at(in6).get_node_shared_ptr())) {
if (reshape0_node->get_input_shape(0).size() != 4) {
return false;
}
std::vector<int64_t> reshapeConstData = {static_cast<int64_t>(reshape0_node->get_input_shape(0)[0] *
reshape0_node->get_input_shape(0)[1] *
reshape0_node->get_input_shape(0)[2]),
-1};
if (reshape_pattern->cast_vector<int64_t>() != reshapeConstData) {
return false;
}
} else {
return false;
}
if (auto reshape1_node = ngraph::as_type_ptr<ngraph::opset1::Reshape>(pattern_to_output.at(reshape1).get_node_shared_ptr())) {
if (reshape0_node->get_input_shape(0) != reshape1_node->get_output_shape(0)) {
return false;
}
} else {
return false;
}
auto softmax_node = ngraph::as_type_ptr<ngraph::opset1::Softmax>(pattern_to_output.at(softmax).get_node_shared_ptr());
if (!softmax_node)
return false;
if (softmax_node->get_axis() != 1)
return false;
std::vector<float> fq1_scale;
auto fq1_node = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize1).get_node_shared_ptr());
if (fq1_node) {
fq1_scale = simplifyToScale(fq1_node);
if (!fq1_scale.size())
return false;
} else {
return false;
}
auto matmul1_node = ngraph::as_type_ptr<ngraph::opset3::MatMul>(pattern_to_output.at(matmul1).get_node_shared_ptr());
if (!matmul1_node)
return false;
if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b())
return false;
std::vector<float> fq2_scale;
if (auto fq_node = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize2).get_node_shared_ptr())) {
fq2_scale = simplifyToScale(fq_node);
if (!fq2_scale.size())
return false;
}
bool is_mul_first = false;
auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr();
auto mha = std::make_shared<ov::intel_cpu::MHANode>(transpose0_in, transpose1_in, add_in1, transpose2_in, mul_scales, is_mul_first,
std::vector<float>(), fq0_scale, fq1_scale, fq2_scale,
ngraph::element::undefined,
fq0_node ? fq0_node->get_output_element_type(0) : ngraph::element::undefined,
fq1_node->get_output_element_type(0), transpose3_node->get_output_element_type(0));
mha->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(),
pattern_to_output.at(transpose1).get_node_shared_ptr(),
pattern_to_output.at(matmul0).get_node_shared_ptr(),
pattern_to_output.at(fakeQuantize0).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(mul).get_node_shared_ptr(),
pattern_to_output.at(reshape0).get_node_shared_ptr(),
pattern_to_output.at(softmax).get_node_shared_ptr(),
pattern_to_output.at(reshape1).get_node_shared_ptr(),
pattern_to_output.at(fakeQuantize1).get_node_shared_ptr(),
pattern_to_output.at(transpose2).get_node_shared_ptr(),
pattern_to_output.at(matmul1).get_node_shared_ptr(),
pattern_to_output.at(fakeQuantize2).get_node_shared_ptr(),
pattern_to_output.at(transpose3).get_node_shared_ptr(),
},
mha);
if (transformation_callback(mha)) {
return false;
}
ngraph::replace_node(m.get_match_root(), mha);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose3, matcher_name);
this->register_matcher(m, callback);
}
// TODO: draw pattern
ov::intel_cpu::MHAQuantFusion2::MHAQuantFusion2() {
MATCHER_SCOPE(MHAQuantFusion2);
auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in2 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in4 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in5 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
auto in9 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto in10 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
auto transpose0 = std::make_shared<ngraph::opset3::Transpose>(in0, in4);
auto transpose1 = std::make_shared<ngraph::opset3::Transpose>(in1, in5);
auto fakeQuantize0 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({transpose1,
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, fakeQuantize0);
auto mul = std::make_shared<ngraph::opset3::Multiply>(matmul0, in2);
auto add = std::make_shared<ngraph::opset4::Add>(mul, in3);
auto softmax = std::make_shared<ngraph::opset1::Softmax>(add);
auto transpose2 = std::make_shared<ngraph::opset3::Transpose>(in8, in9);
auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, transpose2);
auto fakeQuantize1 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({matmul1,
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(fakeQuantize1, in10);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
auto transpose0_in = pattern_to_output.at(in0);
auto transpose1_in = pattern_to_output.at(in1);
auto add_in1 = pattern_to_output.at(in3);
auto transpose2_in = pattern_to_output.at(in8);
if (transpose0_in.get_shape() != transpose1_in.get_shape() || transpose0_in.get_shape() != transpose2_in.get_shape()) {
return false;
}
if (transpose0_in.get_shape().size() != 4) {
return false;
}
auto expected_add_shape = Shape({transpose0_in.get_shape()[0], 1, 1, transpose0_in.get_shape()[1]});
if (add_in1.get_shape() != expected_add_shape) {
return false;
}
std::vector<float> mul_scales;
if (auto mul_node = ngraph::as_type_ptr<ngraph::opset3::Multiply>(pattern_to_output.at(mul).get_node_shared_ptr())) {
mul_scales = ngraph::as_type_ptr<ngraph::opset4::Constant>(mul_node->get_input_node_shared_ptr(1))->cast_vector<float>();
auto expected_shape = ngraph::Shape({1, transpose0_in.get_shape()[2], 1, 1});
if (mul_scales.size() != 1 && mul_node->get_input_shape(1) != expected_shape) {
return false;
}
} else {
return false;
}
if (!valid_transpose_order(pattern_to_output.at(in4).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
if (!valid_transpose_order(pattern_to_output.at(in5).get_node_shared_ptr(), {0, 2, 3, 1})) return false;
if (!valid_transpose_order(pattern_to_output.at(in9).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
if (!valid_transpose_order(pattern_to_output.at(in10).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
auto matmul0_node = ngraph::as_type_ptr<ngraph::opset3::MatMul>(pattern_to_output.at(matmul0).get_node_shared_ptr());
if (!matmul0_node)
return false;
if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b())
return false;
std::vector<float> fq0_scale;
auto fq0_node = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize0).get_node_shared_ptr());
if (fq0_node) {
fq0_scale = simplifyToScale(fq0_node);
if (!fq0_scale.size())
return false;
} else {
return false;
}
auto softmax_node = ngraph::as_type_ptr<ngraph::opset1::Softmax>(pattern_to_output.at(softmax).get_node_shared_ptr());
if (!softmax_node)
return false;
if (softmax_node->get_axis() != 3)
return false;
std::vector<float> fq1_scale;
if (auto fq_node = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize1).get_node_shared_ptr())) {
fq1_scale = simplifyToScale(fq_node);
if (!fq1_scale.size())
return false;
}
auto matmul1_node = ngraph::as_type_ptr<ngraph::opset3::MatMul>(pattern_to_output.at(matmul1).get_node_shared_ptr());
if (!matmul1_node)
return false;
if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b())
return false;
bool is_mul_first = true;
auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr();
auto mha = std::make_shared<ov::intel_cpu::MHANode>(transpose0_in, transpose1_in, add_in1, transpose2_in, mul_scales, is_mul_first,
fq0_scale, std::vector<float>(), std::vector<float>(), fq1_scale,
fq0_node->get_output_element_type(0), ngraph::element::undefined, ngraph::element::undefined,
transpose3_node->get_output_element_type(0));
mha->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(),
pattern_to_output.at(transpose1).get_node_shared_ptr(),
pattern_to_output.at(fakeQuantize0).get_node_shared_ptr(),
pattern_to_output.at(matmul0).get_node_shared_ptr(),
pattern_to_output.at(mul).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(softmax).get_node_shared_ptr(),
pattern_to_output.at(transpose2).get_node_shared_ptr(),
pattern_to_output.at(matmul1).get_node_shared_ptr(),
pattern_to_output.at(fakeQuantize1).get_node_shared_ptr(),
pattern_to_output.at(transpose3).get_node_shared_ptr(),
},
mha);
if (transformation_callback(mha)) {
return false;
}
ngraph::replace_node(m.get_match_root(), mha);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose3, matcher_name);
this->register_matcher(m, callback);
}

View File

@ -0,0 +1,64 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/opsets/opset4.hpp>
namespace ov {
namespace intel_cpu {
class MHAFusionBase : public ngraph::pass::MatcherPass {
protected:
bool valid_transpose_order(const std::shared_ptr<ngraph::Node>& node, const std::vector<int64_t>& expected_order) {
if (auto transpose_pattern = ngraph::as_type_ptr<ngraph::opset4::Constant>(node)) {
if (transpose_pattern->cast_vector<int64_t>() != expected_order) {
return false;
}
} else {
return false;
}
return true;
}
};
class MHAFloatFusion: public MHAFusionBase {
public:
OPENVINO_RTTI("MHAFloatFusion", "0");
MHAFloatFusion();
};
class MHAFloatFusion2: public MHAFusionBase {
public:
OPENVINO_RTTI("MHAFloatFusion2", "0");
MHAFloatFusion2();
};
class MHAQuantFusion: public MHAFusionBase {
public:
OPENVINO_RTTI("MHAQuantFusion", "0");
MHAQuantFusion();
};
class MHAQuantFusion2: public MHAFusionBase {
public:
OPENVINO_RTTI("MHAQuantFusion2", "0");
MHAQuantFusion2();
};
class MHAFusion : public ngraph::pass::GraphRewrite {
public:
OPENVINO_RTTI("MHAFusion", "0");
MHAFusion() {
add_matcher<MHAFloatFusion>();
add_matcher<MHAFloatFusion2>();
add_matcher<MHAQuantFusion>();
add_matcher<MHAQuantFusion2>();
}
};
} // namespace intel_cpu
} // namespace ov

View File

@ -0,0 +1,108 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "mha.hpp"
#include "../itt.hpp"
#include <ngraph/opsets/opset3.hpp>
#include <matmul_shape_inference.hpp>
ov::intel_cpu::MHANode::MHANode(const ngraph::Output<ngraph::Node> &in0,
const ngraph::Output<ngraph::Node> &in1,
const ngraph::Output<ngraph::Node> &in2,
const ngraph::Output<ngraph::Node> &in3,
const std::vector<float> &mul_scales,
bool is_mul_first,
const ngraph::element::Type output_type)
: Op({in0, in1, in2, in3}), m_output_type(output_type) {
this->mul_scales = mul_scales;
this->is_mul_first = is_mul_first;
this->fq0_output_type = ngraph::element::undefined;
this->fq1_output_type = ngraph::element::undefined;
this->fq2_output_type = ngraph::element::undefined;
validate_and_infer_types();
}
ov::intel_cpu::MHANode::MHANode(const ngraph::Output<ngraph::Node> &in0,
const ngraph::Output<ngraph::Node> &in1,
const ngraph::Output<ngraph::Node> &in2,
const ngraph::Output<ngraph::Node> &in3,
const std::vector<float> &mul_scales,
bool is_mul_first,
const std::vector<float> &fq_scales0,
const std::vector<float> &fq_scales1,
const std::vector<float> &fq_scales2,
const std::vector<float> &fq_scales3,
const ngraph::element::Type fq0_output_type,
const ngraph::element::Type fq1_output_type,
const ngraph::element::Type fq2_output_type,
const ngraph::element::Type output_type)
: Op({in0, in1, in2, in3}), m_output_type(output_type) {
this->mul_scales = mul_scales;
this->is_mul_first = is_mul_first;
this->fq_scales0 = fq_scales0;
this->fq_scales1 = fq_scales1;
this->fq_scales2 = fq_scales2;
this->fq_scales3 = fq_scales3;
this->fq0_output_type = fq0_output_type;
this->fq1_output_type = fq1_output_type;
this->fq2_output_type = fq2_output_type;
validate_and_infer_types();
}
std::shared_ptr<ngraph::Node> ov::intel_cpu::MHANode::clone_with_new_inputs(const ngraph::OutputVector& new_args) const {
INTERNAL_OP_SCOPE(MHANode_clone_with_new_inputs);
check_new_args_count(this, new_args);
return std::make_shared<ov::intel_cpu::MHANode>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3),
mul_scales, is_mul_first, fq_scales0, fq_scales1, fq_scales2, fq_scales3,
fq0_output_type, fq1_output_type, fq2_output_type, m_output_type);
}
void ov::intel_cpu::MHANode::validate_and_infer_types() {
INTERNAL_OP_SCOPE(MHANode_validate_and_infer_types);
auto transpose = [](const ov::Shape& shape, const std::vector<size_t>& order) -> ov::Shape {
std::vector<size_t> new_shape(shape.size());
for (int i = 0; i < shape.size(); i++) {
new_shape[i] = shape[order[i]];
}
return new_shape;
};
const auto matmul0_shape0 = transpose(get_input_partial_shape(0).get_shape(), {0, 2, 1, 3});
const auto matmul0_shape1 = transpose(get_input_partial_shape(1).get_shape(), {0, 2, 3, 1});
auto matmul0_in0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, matmul0_shape0);
auto matmul0_in1 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, matmul0_shape1);
auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(matmul0_in0, matmul0_in1);
std::vector<ov::PartialShape> matmul0_input_shapes = {matmul0_shape0, matmul0_shape1};
std::vector<ov::PartialShape> matmul0_output_shapes = {ov::PartialShape{}};
shape_infer(matmul0.get(), matmul0_input_shapes, matmul0_output_shapes);
const auto matmul1_shape0 = matmul0_output_shapes[0];
const auto matmul1_shape1 = transpose(get_input_partial_shape(3).get_shape(), {0, 2, 1, 3});
auto matmul1_in0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, matmul1_shape0);
auto matmul1_in1 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, matmul1_shape1);
auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(matmul1_in0, matmul1_in1);
std::vector<ov::PartialShape> matmul1_input_shapes = {matmul1_shape0, matmul1_shape1};
std::vector<ov::PartialShape> matmul1_output_shapes = {ov::PartialShape{}};
shape_infer(matmul1.get(), matmul1_input_shapes, matmul1_output_shapes);
const auto output_shape = transpose(matmul1_output_shapes[0].get_shape(), {0, 2, 1, 3});
set_output_type(
0,
m_output_type == ngraph::element::undefined || m_output_type == ngraph::element::dynamic ? get_input_element_type(0) : m_output_type,
output_shape);
}
bool ov::intel_cpu::MHANode::visit_attributes(ngraph::AttributeVisitor &visitor) {
INTERNAL_OP_SCOPE(MHANode_visit_attributes);
visitor.on_attribute("out-type", m_output_type);
return true;
}

View File

@ -0,0 +1,94 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ngraph/op/op.hpp>
namespace ov {
namespace intel_cpu {
class MHANode : public ngraph::op::Op {
public:
OPENVINO_OP("MHA", "cpu_plugin_opset");
MHANode() = default;
MHANode(const ngraph::Output<ngraph::Node> &in0,
const ngraph::Output<ngraph::Node> &in1,
const ngraph::Output<ngraph::Node> &in2,
const ngraph::Output<ngraph::Node> &in3,
const std::vector<float> &mul_scales,
bool is_mul_first,
const ngraph::element::Type output_type);
MHANode(const ngraph::Output<ngraph::Node> &in0,
const ngraph::Output<ngraph::Node> &in1,
const ngraph::Output<ngraph::Node> &in2,
const ngraph::Output<ngraph::Node> &in3,
const std::vector<float> &mul_scales,
bool is_mul_first,
const std::vector<float> &fq_scales0,
const std::vector<float> &fq_scales1,
const std::vector<float> &fq_scales2,
const std::vector<float> &fq_scales3,
const ngraph::element::Type fq0_output_type,
const ngraph::element::Type fq1_output_type,
const ngraph::element::Type fq2_output_type,
const ngraph::element::Type output_type);
void validate_and_infer_types() override;
bool visit_attributes(ngraph::AttributeVisitor &visitor) override;
std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector &new_args) const override;
ngraph::element::Type get_output_type() const { return m_output_type; }
const std::vector<float>& get_mul_scales() const {
return mul_scales;
}
const std::vector<float>& get_fq_scales0() const {
return fq_scales0;
}
const std::vector<float>& get_fq_scales1() const {
return fq_scales1;
}
const std::vector<float>& get_fq_scales2() const {
return fq_scales2;
}
const std::vector<float>& get_fq_scales3() const {
return fq_scales3;
}
bool get_is_mul_first() const {
return is_mul_first;
}
ngraph::element::Type get_fq0_output_type() const {
return fq0_output_type;
}
ngraph::element::Type get_fq1_output_type() const {
return fq1_output_type;
}
ngraph::element::Type get_fq2_output_type() const {
return fq2_output_type;
}
private:
ngraph::element::Type m_output_type;
std::vector<float> mul_scales;
bool is_mul_first;
std::vector<float> fq_scales0;
std::vector<float> fq_scales1;
std::vector<float> fq_scales2;
std::vector<float> fq_scales3;
ngraph::element::Type fq0_output_type;
ngraph::element::Type fq1_output_type;
ngraph::element::Type fq2_output_type;
};
} // namespace intel_cpu
} // namespace ov

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,243 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <node.h>
#include <memory>
#include <string>
#include <vector>
#include <cpu/x64/brgemm/brgemm.hpp>
#include <cpu/x64/matmul/brgemm_matmul_copy_utils.hpp>
#include <cpu/x64/matmul/brgemm_matmul_utils.hpp>
#include <cpu/x64/amx_tile_configure.hpp>
namespace ov {
namespace intel_cpu {
namespace node {
struct jit_mul_add_softmax_compile_params {
InferenceEngine::Precision src_prc;
InferenceEngine::Precision dst_prc;
size_t work_amount;
bool with_mul_scales;
bool is_mul_first;
bool with_scales0;
bool broadcast_scales0;
bool with_scales1;
bool broadcast_scales1;
};
struct jit_mul_add_softmax_call_args {
const void *p_in0;
const void *p_mul_in1;
const void *p_add_in1;
void *p_out;
void *p_buffer;
const void *p_scales0;
const void *p_scales1;
};
struct jit_uni_mul_add_softmax_kernel {
void (*ker_)(const jit_mul_add_softmax_call_args*);
void operator()(const jit_mul_add_softmax_call_args* call_args) {
assert(ker_);
ker_(call_args);
}
explicit jit_uni_mul_add_softmax_kernel(const jit_mul_add_softmax_compile_params& jcp) : ker_(nullptr), jcp_(jcp) {}
virtual ~jit_uni_mul_add_softmax_kernel() {}
virtual void create_ker() = 0;
jit_mul_add_softmax_compile_params jcp_;
};
struct jit_convert_reorder_compile_params {
InferenceEngine::Precision src_prc;
InferenceEngine::Precision dst_prc;
size_t inner_work_amount;
bool with_scales;
bool broadcast_scales;
size_t src_stride;
size_t dst_stride;
};
struct jit_convert_reorder_call_args {
const void *p_in;
void *p_out;
const void *p_scales;
size_t outter_work_amount;
};
struct jit_uni_convert_reorder_kernel {
void (*ker_)(const jit_convert_reorder_call_args*);
void operator()(const jit_convert_reorder_call_args* call_args) {
assert(ker_);
ker_(call_args);
}
explicit jit_uni_convert_reorder_kernel(const jit_convert_reorder_compile_params& jcp) : ker_(nullptr), jcp_(jcp) {}
virtual ~jit_uni_convert_reorder_kernel() {}
virtual void create_ker() = 0;
jit_convert_reorder_compile_params jcp_;
};
struct jit_convert_transpose_compile_params {
InferenceEngine::Precision src_prc;
InferenceEngine::Precision dst_prc;
size_t inner_work_amount;
size_t outter_work_amount;
bool with_scales;
bool broadcast_scales;
size_t inner_src_stride;
size_t outter_src_stride;
size_t outter_dst_stride;
};
struct jit_convert_transpose_call_args {
const void *p_in;
void *p_out;
const void *p_scales;
};
struct jit_uni_convert_transpose_kernel {
void (*ker_)(const jit_convert_transpose_call_args*);
void operator()(const jit_convert_transpose_call_args* call_args) {
assert(ker_);
ker_(call_args);
}
explicit jit_uni_convert_transpose_kernel(const jit_convert_transpose_compile_params& jcp) : ker_(nullptr), jcp_(jcp) {}
virtual ~jit_uni_convert_transpose_kernel() {}
virtual void create_ker() = 0;
jit_convert_transpose_compile_params jcp_;
};
#define MHA_BRGEMM_KERNELS_NUM 8
class MHA : public Node {
public:
MHA(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache);
void getSupportedDescriptors() override {};
void initSupportedPrimitiveDescriptors() override;
void execute(dnnl::stream strm) override;
bool created() const override;
static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;
protected:
void executeDynamicImpl(dnnl::stream strm) override;
void prepareParams() override;
private:
struct brgemmCtx {
size_t M, N, K, LDA, LDB, LDC;
dnnl_data_type_t dt_in0, dt_in1;
char palette[64];
bool is_with_amx;
bool is_with_comp;
float beta;
};
template <typename in1_type>
void mhaImpl();
void init_brgemm(brgemmCtx& ctx, std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& brgKernel, bool use_amx);
void init_brgemm_copy_a(std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_a_t>& brgCopyKernel,
size_t K, size_t K_blk, size_t K_tail, size_t LDA, dnnl_data_type_t dt_in0);
void init_brgemm_copy_b(std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t>& brgCopyKernel,
size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K, bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1);
void callBrgemm(brgemmCtx& ctx, std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& brgKernel,
const void* pin0, const void* pin1, void* pout, void* wsp);
size_t getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) {
return mIdx * 4 + kIdx * 2 + nIdx;
}
std::vector<InferenceEngine::Precision> inputPrecisions;
InferenceEngine::Precision accPrecision0;
InferenceEngine::Precision accPrecision1;
VectorDims dimsTranspose0In0;
VectorDims dimsTranspose1In0;
VectorDims dimsMulIn1;
VectorDims dimsAddIn1;
VectorDims dimsTranspose2In0;
VectorDims dimsOut;
VectorDims strTranspose0In0;
VectorDims strTranspose1In0;
VectorDims strMulIn1;
VectorDims strAddIn1;
VectorDims strTranspose2In0;
VectorDims strOut;
VectorDims dimsMatMul0In0;
VectorDims dimsMatMul0In1;
VectorDims dimsMatMul0Out;
VectorDims dimsMatMul1In1;
VectorDims dimsMatMul1Out;
size_t batch0, batch1;
size_t M, M_blk, M_tail;
size_t K0, K0_blk, K0_tail, N0, N0_blk, N0_tail;
size_t K1, K1_blk, K1_tail, N1, N1_blk, N1_tail;
size_t bufferMatMul0In0Size;
size_t bufferMatMul0In1Size;
size_t bufferMatMul0OutSize;
size_t bufferMatMul1In1Size;
size_t bufferMatMul1OutSize;
size_t bufferCompensation0Size;
size_t bufferCompensation1Size;
size_t wsp_size_per_thread = 4 * 1024;
std::vector<uint8_t> bufferMatMul0In0;
std::vector<uint8_t> bufferMatMul0In1;
std::vector<uint8_t> bufferMatMul0Out;
std::vector<uint8_t> bufferMatMul1In1;
std::vector<uint8_t> bufferMatMul1Out;
std::vector<int32_t> bufferCompensation0;
std::vector<int32_t> bufferCompensation1;
std::vector<size_t> wsp;
bool isMulFirst;
InferenceEngine::Precision fqPrc2;
std::vector<float> mulScales;
std::vector<float> fqScales0;
std::vector<float> fqScales1;
std::vector<float> fqScales2;
std::vector<float> fqScales3;
size_t brg0VnniFactor;
brgemmCtx brgCtxs0[MHA_BRGEMM_KERNELS_NUM];
std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t> brgKernels0[MHA_BRGEMM_KERNELS_NUM];
std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_a_t> brgCopyAKernel0;
std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t> brgCopyBKernel0;
size_t brg1VnniFactor;
brgemmCtx brgCtxs1[MHA_BRGEMM_KERNELS_NUM];
std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t> brgKernels1[MHA_BRGEMM_KERNELS_NUM];
std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t> brgCopyBKernel1;
std::unique_ptr<jit_uni_mul_add_softmax_kernel> mulAddSoftmaxKernel;
std::unique_ptr<jit_uni_convert_reorder_kernel> convertReorderKernel;
std::unique_ptr<jit_uni_convert_transpose_kernel> convertTransposeKernel;
};
} // namespace node
} // namespace intel_cpu
} // namespace ov

View File

@ -88,6 +88,7 @@
#include "nodes/priorbox.h" #include "nodes/priorbox.h"
#include "nodes/priorbox_clustered.h" #include "nodes/priorbox_clustered.h"
#include "nodes/eye.h" #include "nodes/eye.h"
#include "nodes/mha.h"
namespace ov { namespace ov {
namespace intel_cpu { namespace intel_cpu {
@ -188,6 +189,7 @@ Node::NodesFactory::NodesFactory()
INTEL_CPU_NODE(PriorBox, Type::PriorBox); INTEL_CPU_NODE(PriorBox, Type::PriorBox);
INTEL_CPU_NODE(PriorBoxClustered, Type::PriorBoxClustered); INTEL_CPU_NODE(PriorBoxClustered, Type::PriorBoxClustered);
INTEL_CPU_NODE(Eye, Type::Eye); INTEL_CPU_NODE(Eye, Type::Eye);
INTEL_CPU_NODE(MHA, Type::MHA);
} }
#undef INTEL_CPU_NODE #undef INTEL_CPU_NODE

View File

@ -87,6 +87,7 @@
#include <transformations/op_conversions/convert_roi_align_v3_to_v9.hpp> #include <transformations/op_conversions/convert_roi_align_v3_to_v9.hpp>
#include <transformations/op_conversions/softsign_decomposition.hpp> #include <transformations/op_conversions/softsign_decomposition.hpp>
#include "transformations/op_conversions/eye_decomposition.hpp" #include "transformations/op_conversions/eye_decomposition.hpp"
#include "ngraph_transformations/mha_fusion.hpp"
#include <ngraph/opsets/opset1.hpp> #include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset2.hpp> #include <ngraph/opsets/opset2.hpp>
@ -118,6 +119,7 @@
#include "nodes/mvn.h" #include "nodes/mvn.h"
#include "nodes/fake_quantize.h" #include "nodes/fake_quantize.h"
#include "nodes/normalize.h" #include "nodes/normalize.h"
#include "nodes/mha.h"
#include "ngraph_transformations/convert_to_cpu_specific_opset.hpp" #include "ngraph_transformations/convert_to_cpu_specific_opset.hpp"
#include "ngraph_transformations/move_eltwise_up_data_movement.hpp" #include "ngraph_transformations/move_eltwise_up_data_movement.hpp"
#include "transformations/smart_reshape/smart_reshape.hpp" #include "transformations/smart_reshape/smart_reshape.hpp"
@ -249,7 +251,7 @@ Engine::~Engine() {
executorManager()->clear("CPUCallbackExecutor"); executorManager()->clear("CPUCallbackExecutor");
} }
static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function> nGraphFunc, const bool _enableLPT, static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function> nGraphFunc, const bool _enableLPT, const bool _enableBF16,
const bool _enableSnippets, const bool isLegacyApi) { const bool _enableSnippets, const bool isLegacyApi) {
ngraph::pass::Manager manager; ngraph::pass::Manager manager;
manager.set_per_pass_validation(false); manager.set_per_pass_validation(false);
@ -600,6 +602,27 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
}); });
postLPTPassManager.register_pass<ngraph::pass::ConstantFolding>(); postLPTPassManager.register_pass<ngraph::pass::ConstantFolding>();
// Snippets may brake MHA patterns so the fusion has to performed before
postLPTPassManager.register_pass<MHAFusion>();
postLPTPassManager.get_pass_config()->set_callback<MHAFloatFusion, MHAFloatFusion2,
MHAQuantFusion, MHAQuantFusion2>([_enableBF16](const std::shared_ptr<const ov::Node>& n) -> bool {
std::string errorMessage;
if (!node::MHA::isSupportedOperation(n, errorMessage))
return true;
// Implementation calls AMX BF16 brgemm only for tensors with K and N aligned on 2, otherwise fallbacks on vector impl
// Vector madd BF16 instruction on SPR has reduced performance on HW level, which results in overall perf degradation
size_t bf16Factor = 2;
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16_amx_bf16) &&
(n->get_input_element_type(0) == element::bf16 || (n->get_input_element_type(0) == element::f32 && _enableBF16)) &&
(n->get_input_shape(0)[3] % bf16Factor != 0 || n->get_input_shape(1)[1] % bf16Factor != 0 || n->get_input_shape(3)[3] % bf16Factor != 0)) {
return true;
}
return false;
});
postLPTPassManager.run_passes(nGraphFunc); postLPTPassManager.run_passes(nGraphFunc);
if (!useLpt && _enableSnippets && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) { if (!useLpt && _enableSnippets && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) {
@ -631,9 +654,9 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
} }
} }
static void Transformation(CNNNetwork& clonedNetwork, const bool _enableLPT, const bool _enableSnippets, const bool isLegacyApi) { static void Transformation(CNNNetwork& clonedNetwork, const bool _enableLPT, const bool _enableBF16, const bool _enableSnippets, const bool isLegacyApi) {
auto nGraphFunc = clonedNetwork.getFunction(); auto nGraphFunc = clonedNetwork.getFunction();
TransformationUpToCPUSpecificOpSet(nGraphFunc, _enableLPT, _enableSnippets, isLegacyApi); TransformationUpToCPUSpecificOpSet(nGraphFunc, _enableLPT, _enableBF16, _enableSnippets, isLegacyApi);
ConvertToCPUSpecificOpset(nGraphFunc); ConvertToCPUSpecificOpset(nGraphFunc);
} }
@ -774,7 +797,7 @@ Engine::LoadExeNetworkImpl(const InferenceEngine::CNNNetwork &network, const std
|| engConfig.enableDynamicBatch; || engConfig.enableDynamicBatch;
const bool enableSnippets = !(enableModelCache || enableDynamicBatch || enableBF16); const bool enableSnippets = !(enableModelCache || enableDynamicBatch || enableBF16);
auto nGraphFunc = clonedNetwork.getFunction(); auto nGraphFunc = clonedNetwork.getFunction();
TransformationUpToCPUSpecificOpSet(nGraphFunc, enableLPT, enableSnippets, isLegacyAPI()); TransformationUpToCPUSpecificOpSet(nGraphFunc, enableLPT, enableBF16, enableSnippets, isLegacyAPI());
// need to check that all outputs have static shapes // need to check that all outputs have static shapes
// checking that all inputs have static shapes is performed in the common part // checking that all inputs have static shapes is performed in the common part
@ -1022,7 +1045,7 @@ QueryNetworkResult Engine::QueryNetwork(const CNNNetwork& network, const std::ma
|| Config::LPTransformsMode::On == engConfig.lpTransformsMode /* or already enabled */; || Config::LPTransformsMode::On == engConfig.lpTransformsMode /* or already enabled */;
const bool enableSnippets = !(conf.cache_dir.empty() || conf.enableDynamicBatch || (conf.enforceBF16 const bool enableSnippets = !(conf.cache_dir.empty() || conf.enableDynamicBatch || (conf.enforceBF16
&& dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core))); && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core)));
Transformation(clonedNetwork, enableLPT, enableSnippets, isLegacyAPI()); Transformation(clonedNetwork, enableLPT, conf.enforceBF16, enableSnippets, isLegacyAPI());
auto ops = clonnedFunction->get_ordered_ops(); auto ops = clonnedFunction->get_ordered_ops();
//Mark removed nodes as supported //Mark removed nodes as supported

View File

@ -0,0 +1,549 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <tuple>
#include <string>
#include <vector>
#include <memory>
#include <debug.h>
#include <shared_test_classes/base/ov_subgraph.hpp>
#include <ngraph_functions/builders.hpp>
#include "common_test_utils/common_utils.hpp"
#include <common_test_utils/ov_tensor_utils.hpp>
#include "functional_test_utils/skip_tests_config.hpp"
#include "test_utils/cpu_test_utils.hpp"
using namespace CPUTestUtils;
using namespace ov::test;
using namespace ngraph::helpers;
namespace CPUSubgraphTestsDefinitions {
typedef std::tuple<
std::vector<InputShape>, // Input shapes
std::vector<ElementType>, // Input precisions
std::vector<ElementType>, // MatMul input #0 precisions
size_t, // pattern type #
std::string // Device name
> MHATuple;
static std::shared_ptr<ov::Model> initMHASubgraph0(std::vector<ov::PartialShape>& inputDynamicShapes, std::vector<ElementType>& inputPrecisions) {
ngraph::ParameterVector ngraphParam;
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[0], inputDynamicShapes[0]);
ngraphParam.push_back(transpose0Param);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[1], inputDynamicShapes[1]);
ngraphParam.push_back(transpose1Param);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[2], inputDynamicShapes[2]);
ngraphParam.push_back(addParam);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[3], inputDynamicShapes[3]);
ngraphParam.push_back(transpose2Param);
std::vector<ov::Shape> constantShapes;
constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
constantShapes.push_back(ov::Shape({1, inputDynamicShapes[0].get_shape()[2], 1, 1}));
constantShapes.push_back(ov::Shape({2}));
constantShapes.push_back(ov::Shape({4}));
constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
std::vector<int64_t> transpose0ConstData = {0, 2, 1, 3};
auto transpose0Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[0], transpose0ConstData);
std::vector<int64_t> transpose1ConstData = {0, 2, 3, 1};
auto transpose1Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[1], transpose1ConstData);
std::vector<float> mulConstData(ngraph::shape_size(constantShapes[2]));
auto mulConst = ngraph::builder::makeConstant(inputPrecisions[0], constantShapes[2], mulConstData, true);
std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(inputDynamicShapes[0].get_shape()[0] *
inputDynamicShapes[0].get_shape()[1] * inputDynamicShapes[0].get_shape()[2]),
-1};
auto reshape0Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[3], reshape0ConstData);
std::vector<int64_t> reshape1ConstData = {static_cast<int64_t>(inputDynamicShapes[0].get_shape()[0]),
static_cast<int64_t>(inputDynamicShapes[0].get_shape()[2]),
static_cast<int64_t>(inputDynamicShapes[0].get_shape()[1]),
static_cast<int64_t>(inputDynamicShapes[0].get_shape()[1])};
auto reshape1Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[4], reshape1ConstData);
std::vector<int64_t> transpose2ConstData = {0, 2, 1, 3};
auto transpose2Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[5], transpose2ConstData);
std::vector<int64_t> transpose3ConstData = {0, 2, 1, 3};
auto transpose3Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[6], transpose3ConstData);
float transA = false;
float transB = false;
const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(transpose0Param, transpose0Const);
const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(transpose1Param, transpose1Const);
const auto mul = std::make_shared<ngraph::opset3::Multiply>(transpose1, mulConst);
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, mul, transA, transB);
const auto add = std::make_shared<ngraph::opset3::Add>(matMul0, addParam);
const auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(add, reshape0Const, true);
const auto softMax = std::make_shared<ngraph::opset1::Softmax>(reshape0, 1);
const auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softMax, reshape1Const, true);
const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(transpose2Param, transpose2Const);
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(reshape1, transpose2, transA, transB);
const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(matMul1, transpose3Const);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
return std::make_shared<ngraph::Function>(results, ngraphParam, "mha");
}
static std::shared_ptr<ov::Model> initMHASubgraph1(std::vector<ov::PartialShape>& inputDynamicShapes, std::vector<ElementType>& inputPrecisions) {
ngraph::ParameterVector ngraphParam;
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[0], inputDynamicShapes[0]);
ngraphParam.push_back(transpose0Param);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[1], inputDynamicShapes[1]);
ngraphParam.push_back(transpose1Param);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[2], inputDynamicShapes[2]);
ngraphParam.push_back(addParam);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[3], inputDynamicShapes[3]);
ngraphParam.push_back(transpose2Param);
std::vector<ov::Shape> constantShapes;
constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
constantShapes.push_back(ov::Shape({1, inputDynamicShapes[0].get_shape()[2], 1, 1}));
constantShapes.push_back(ov::Shape({2}));
constantShapes.push_back(ov::Shape({4}));
constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
std::vector<int64_t> transpose0ConstData = {0, 2, 1, 3};
auto transpose0Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[0], transpose0ConstData);
std::vector<int64_t> transpose1ConstData = {0, 2, 3, 1};
auto transpose1Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[1], transpose1ConstData);
std::vector<int64_t> transpose2ConstData = {0, 2, 1, 3};
auto transpose2Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[0], transpose2ConstData);
std::vector<int64_t> transpose3ConstData = {0, 2, 1, 3};
auto transpose3Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[1], transpose3ConstData);
float transA = false;
float transB = false;
const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(transpose0Param, transpose0Const);
const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(transpose1Param, transpose1Const);
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1, transA, transB);
const auto add = std::make_shared<ngraph::opset3::Add>(matMul0, addParam);
const auto softMax = std::make_shared<ngraph::opset1::Softmax>(add, 3);
const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(transpose2Param, transpose2Const);
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softMax, transpose2, transA, transB);
const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(matMul1, transpose3Const);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
return std::make_shared<ngraph::Function>(results, ngraphParam, "mha");
}
class MHATest : public testing::WithParamInterface<MHATuple>,
virtual public SubgraphBaseTest, public CPUTestsBase {
public:
static std::string getTestCaseName(const testing::TestParamInfo<MHATuple> &obj) {
std::vector<InputShape> inputShapes;
std::vector<ElementType> inputPrecisions;
std::vector<ElementType> matMulIn0Precisions;
size_t patternType;
std::string targetName;
std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetName) = obj.param;
std::ostringstream results;
results << "IS=(";
for (const auto& shape : inputShapes) {
results << CommonTestUtils::partialShape2str({shape.first}) << "_";
}
results << ")_TS=(";
for (const auto& shape : inputShapes) {
for (const auto& item : shape.second) {
results << CommonTestUtils::vec2str(item) << "_";
}
}
for (int i = 0; i < inputPrecisions.size(); i++) {
results << "InPRC" << std::to_string(i) << "=" << inputPrecisions[i] << "_";
}
results << "patternType=" << patternType;
results << "targetDevice=" << targetName;
return results.str();
}
void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override {
inputs.clear();
const auto& funcInputs = function->inputs();
for (int i = 0; i < funcInputs.size(); ++i) {
const auto& funcInput = funcInputs[i];
ov::Tensor tensor;
tensor = ov::test::utils::create_and_fill_tensor_normal_distribution(funcInput.get_element_type(), targetInputStaticShapes[i], 1.0f, 0.5f);
inputs.insert({funcInput.get_node_shared_ptr(), tensor});
}
}
protected:
void SetUp() override {
std::vector<InputShape> inputShapes;
std::vector<ElementType> inputPrecisions;
std::vector<ElementType> matMulIn0Precisions;
size_t patternType;
std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam();
init_input_shapes(inputShapes);
if (patternType == 0) {
function = initMHASubgraph0(inputDynamicShapes, inputPrecisions);
} else if (patternType == 1) {
function = initMHASubgraph1(inputDynamicShapes, inputPrecisions);
} else {
FAIL() << "Unsupported MHA pattern type";
}
// TODO: try better input data initialization to avoid threshold adjustment
// TODO: support different precisions on inputs
if (inputPrecisions[0] == ElementType::bf16) {
abs_threshold = 0.1f;
rel_threshold = 10.f;
configuration.insert({{ InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16, InferenceEngine::PluginConfigParams::YES }});
}
}
};
TEST_P(MHATest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
std::vector<InputShape> inputShapes;
std::vector<ElementType> inputPrecisions;
std::vector<ElementType> matMulIn0Precisions;
size_t patternType;
std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam();
if (inputPrecisions[0] == ElementType::bf16 && !InferenceEngine::with_cpu_x86_bfloat16())
GTEST_SKIP();
if (!InferenceEngine::with_cpu_x86_avx512_core())
GTEST_SKIP();
run();
CheckNumberOfNodesWithType(compiledModel, "MHA", 1);
}
namespace {
std::vector<std::vector<ngraph::Shape>> inputShapes = {
{{2, 8, 16, 64}, {2, 8, 16, 64}, {2, 1, 1, 8}, {2, 8, 16, 64}},
{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}},
{{2, 64, 16, 80}, {2, 64, 16, 80}, {2, 1, 1, 64}, {2, 64, 16, 80}},
{{3, 96, 16, 64}, {3, 96, 16, 64}, {3, 1, 1, 96}, {3, 96, 16, 64}},
{{2, 192, 16, 160}, {2, 192, 16, 160}, {2, 1, 1, 192}, {2, 192, 16, 160}},
{{2, 4, 16, 8}, {2, 4, 16, 8}, {2, 1, 1, 4}, {2, 4, 16, 8}},
{{1, 204, 13, 212}, {1, 204, 13, 212}, {1, 1, 1, 204}, {1, 204, 13, 212}},
};
std::vector<std::vector<ElementType>> inputPrecisions = {
{ ElementType::f32, ElementType::f32, ElementType::f32, ElementType::f32 },
{ ElementType::bf16, ElementType::bf16, ElementType::bf16, ElementType::bf16 },
};
std::vector<std::vector<ElementType>> matMulIn0Precisions = {
{},
};
std::vector<size_t> patternTypes = {
0, 1
};
INSTANTIATE_TEST_SUITE_P(smoke_MHA, MHATest,
::testing::Combine(
::testing::ValuesIn(static_shapes_to_test_representation(inputShapes)),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(matMulIn0Precisions),
::testing::ValuesIn(patternTypes),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
MHATest::getTestCaseName);
} // namespace
static std::shared_ptr<ov::Model> initMHAQuantSubgraph0(std::vector<ov::PartialShape>& inputDynamicShapes, std::vector<ElementType>& inputPrecisions,
std::vector<ElementType>& matMulIn0Precisions) {
ngraph::ParameterVector ngraphParam;
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[0], inputDynamicShapes[0]);
ngraphParam.push_back(transpose0Param);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[1], inputDynamicShapes[1]);
ngraphParam.push_back(transpose1Param);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[2], inputDynamicShapes[2]);
ngraphParam.push_back(addParam);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[3], inputDynamicShapes[3]);
ngraphParam.push_back(transpose2Param);
std::vector<ov::Shape> constantShapes;
constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
constantShapes.push_back(ov::Shape({2}));
constantShapes.push_back(ov::Shape({4}));
constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
std::vector<int64_t> transpose0ConstData = {0, 2, 1, 3};
auto transpose0Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[0], transpose0ConstData);
std::vector<int64_t> transpose1ConstData = {0, 2, 3, 1};
auto transpose1Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[1], transpose1ConstData);
std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(inputDynamicShapes[0].get_shape()[0] *
inputDynamicShapes[0].get_shape()[1] * inputDynamicShapes[0].get_shape()[2]),
-1};
auto reshape0Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[2], reshape0ConstData);
std::vector<int64_t> reshape1ConstData = {static_cast<int64_t>(inputDynamicShapes[0].get_shape()[0]),
static_cast<int64_t>(inputDynamicShapes[0].get_shape()[2]),
static_cast<int64_t>(inputDynamicShapes[0].get_shape()[1]),
static_cast<int64_t>(inputDynamicShapes[0].get_shape()[1])};
auto reshape1Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[3], reshape1ConstData);
std::vector<int64_t> transpose2ConstData = {0, 2, 1, 3};
auto transpose2Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[4], transpose2ConstData);
std::vector<int64_t> transpose3ConstData = {0, 2, 1, 3};
auto transpose3Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[5], transpose3ConstData);
float transA = false;
float transB = false;
std::shared_ptr<ov::Node> fakeQuantize0;
if (matMulIn0Precisions[0] == ElementType::u8)
fakeQuantize0 = ngraph::builder::makeFakeQuantize(transpose0Param, inputPrecisions[0], 256, {}, {0.0f}, {2.55f}, {0.0f}, {2.55f});
else
fakeQuantize0 = ngraph::builder::makeFakeQuantize(transpose0Param, inputPrecisions[0], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
const auto fakeQuantize1 = ngraph::builder::makeFakeQuantize(transpose1Param, inputPrecisions[1], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
const auto fakeQuantize2 = ngraph::builder::makeFakeQuantize(transpose2Param, inputPrecisions[3], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
std::shared_ptr<ov::Node> fakeQuantize4;
const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(fakeQuantize0, transpose0Const);
const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(fakeQuantize1, transpose1Const);
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1, transA, transB);
const auto fakeQuantize3 = ngraph::builder::makeFakeQuantize(matMul0, inputPrecisions[0], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
const auto add = std::make_shared<ngraph::opset3::Add>(fakeQuantize3, addParam);
const auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(add, reshape0Const, true);
const auto softMax = std::make_shared<ngraph::opset1::Softmax>(reshape0, 1);
const auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softMax, reshape1Const, true);
if (matMulIn0Precisions[1] == ElementType::u8)
fakeQuantize4 = ngraph::builder::makeFakeQuantize(reshape1, inputPrecisions[0], 256, {}, {0.0f}, {0.255f}, {0.0f}, {0.255f});
else
fakeQuantize4 = ngraph::builder::makeFakeQuantize(reshape1, inputPrecisions[0], 256, {}, {-0.128f}, {0.127f}, {-0.128f}, {0.127f});
const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(fakeQuantize2, transpose2Const);
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(fakeQuantize4, transpose2, transA, transB);
const auto fakeQuantize5 = ngraph::builder::makeFakeQuantize(matMul1, inputPrecisions[0], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(fakeQuantize5, transpose3Const);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
return std::make_shared<ngraph::Function>(results, ngraphParam, "mha");
}
static std::shared_ptr<ov::Model> initMHAQuantSubgraph1(std::vector<ov::PartialShape>& inputDynamicShapes, std::vector<ElementType>& inputPrecisions,
std::vector<ElementType>& matMulIn0Precisions) {
ngraph::ParameterVector ngraphParam;
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[0], inputDynamicShapes[0]);
ngraphParam.push_back(transpose0Param);
auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[1], inputDynamicShapes[1]);
ngraphParam.push_back(transpose1Param);
auto addParam = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[2], inputDynamicShapes[2]);
ngraphParam.push_back(addParam);
auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[3], inputDynamicShapes[3]);
ngraphParam.push_back(transpose2Param);
std::vector<ov::Shape> constantShapes;
constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
constantShapes.push_back(ov::Shape({1}));
std::vector<int64_t> transpose0ConstData = {0, 2, 1, 3};
auto transpose0Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[0], transpose0ConstData);
std::vector<int64_t> transpose1ConstData = {0, 2, 3, 1};
auto transpose1Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[1], transpose1ConstData);
std::vector<int64_t> transpose2ConstData = {0, 2, 1, 3};
auto transpose2Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[2], transpose2ConstData);
std::vector<int64_t> transpose3ConstData = {0, 2, 1, 3};
auto transpose3Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[3], transpose3ConstData);
std::vector<float> mulConstData(ngraph::shape_size(constantShapes[4]));
auto mulConst = ngraph::builder::makeConstant(inputPrecisions[0], constantShapes[4], mulConstData, true);
float transA = false;
float transB = false;
std::shared_ptr<ov::Node> fakeQuantize0;
if (matMulIn0Precisions[0] == ElementType::u8)
fakeQuantize0 = ngraph::builder::makeFakeQuantize(transpose0Param, inputPrecisions[0], 256, {}, {0.0f}, {2.55f}, {0.0f}, {2.55f});
else
fakeQuantize0 = ngraph::builder::makeFakeQuantize(transpose0Param, inputPrecisions[0], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(fakeQuantize0, transpose0Const);
const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(transpose1Param, transpose1Const);
const auto fakeQuantize1 = ngraph::builder::makeFakeQuantize(transpose1, inputPrecisions[1], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, fakeQuantize1, transA, transB);
const auto mul = std::make_shared<ngraph::opset3::Multiply>(addParam, mulConst);
const auto add = std::make_shared<ngraph::opset3::Add>(matMul0, mul);
const auto softMax = std::make_shared<ngraph::opset1::Softmax>(add, 3);
const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(transpose2Param, transpose2Const);
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softMax, transpose2, transA, transB);
const auto fakeQuantize2 = ngraph::builder::makeFakeQuantize(matMul1, inputPrecisions[0], 256, {}, {0.0f}, {2.55f}, {0.0f}, {2.55f});
const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(fakeQuantize2, transpose3Const);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
return std::make_shared<ngraph::Function>(results, ngraphParam, "mha");
}
class MHAQuantTest : public testing::WithParamInterface<MHATuple>,
virtual public SubgraphBaseTest, public CPUTestsBase {
public:
static std::string getTestCaseName(const testing::TestParamInfo<MHATuple> &obj) {
std::vector<InputShape> inputShapes;
std::vector<ElementType> inputPrecisions;
std::vector<ElementType> matMulIn0Precisions;
size_t patternType;
std::string targetName;
std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetName) = obj.param;
std::ostringstream results;
results << "IS=(";
for (const auto& shape : inputShapes) {
results << CommonTestUtils::partialShape2str({shape.first}) << "_";
}
results << ")_TS=(";
for (const auto& shape : inputShapes) {
for (const auto& item : shape.second) {
results << CommonTestUtils::vec2str(item) << "_";
}
}
for (int i = 0; i < inputPrecisions.size(); i++) {
results << "InPRC" << std::to_string(i) << "=" << inputPrecisions[i] << "_";
}
for (int i = 0; i < matMulIn0Precisions.size(); i++) {
results << "MatMulIn0PRC" << std::to_string(i) << "=" << matMulIn0Precisions[i] << "_";
}
results << "patternType=" << patternType;
results << "targetDevice=" << targetName;
return results.str();
}
void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override {
inputs.clear();
const auto& funcInputs = function->inputs();
for (int i = 0; i < funcInputs.size(); ++i) {
const auto& funcInput = funcInputs[i];
ov::Tensor tensor;
if (funcInput.get_element_type().is_real())
tensor = ov::test::utils::create_and_fill_tensor_normal_distribution(funcInput.get_element_type(), targetInputStaticShapes[i], 0.0f, 1.5f);
else
tensor = ov::test::utils::create_and_fill_tensor(funcInput.get_element_type(), targetInputStaticShapes[i], 255, 0, 1);
inputs.insert({funcInput.get_node_shared_ptr(), tensor});
}
}
protected:
void SetUp() override {
abs_threshold = 0.1f;
std::vector<InputShape> inputShapes;
std::vector<ElementType> inputPrecisions;
std::vector<ElementType> matMulIn0Precisions;
size_t patternType;
std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam();
init_input_shapes(inputShapes);
if (patternType == 0) {
function = initMHAQuantSubgraph0(inputDynamicShapes, inputPrecisions, matMulIn0Precisions);
} else if (patternType == 1) {
function = initMHAQuantSubgraph1(inputDynamicShapes, inputPrecisions, matMulIn0Precisions);
} else {
FAIL() << "Unsupported MHA pattern type";
}
}
};
TEST_P(MHAQuantTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
std::vector<InputShape> inputShapes;
std::vector<ElementType> inputPrecisions;
std::vector<ElementType> matMulIn0Precisions;
size_t patternType;
std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam();
if (inputPrecisions[0] == ElementType::bf16 && !InferenceEngine::with_cpu_x86_bfloat16())
GTEST_SKIP();
if (!InferenceEngine::with_cpu_x86_avx512_core_vnni())
GTEST_SKIP();
run();
CheckNumberOfNodesWithType(compiledModel, "MHA", 1);
}
namespace {
std::vector<std::vector<ngraph::Shape>> inputShapesQuant = {
{{2, 7, 16, 9}, {2, 7, 16, 9}, {2, 1, 1, 7}, {2, 7, 16, 9}},
{{2, 8, 16, 64}, {2, 8, 16, 64}, {2, 1, 1, 8}, {2, 8, 16, 64}},
{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}},
{{2, 64, 16, 80}, {2, 64, 16, 80}, {2, 1, 1, 64}, {2, 64, 16, 80}},
{{3, 96, 16, 64}, {3, 96, 16, 64}, {3, 1, 1, 96}, {3, 96, 16, 64}},
{{2, 192, 16, 160}, {2, 192, 16, 160}, {2, 1, 1, 192}, {2, 192, 16, 160}},
{{2, 4, 16, 8}, {2, 4, 16, 8}, {2, 1, 1, 4}, {2, 4, 16, 8}},
{{1, 204, 13, 212}, {1, 204, 13, 212}, {1, 1, 1, 204}, {1, 204, 13, 212}},
{{1, 207, 13, 211}, {1, 207, 13, 211}, {1, 1, 1, 207}, {1, 207, 13, 211}},
};
std::vector<std::vector<ElementType>> inputPrecisionsQuant = {
{ ElementType::f32, ElementType::f32, ElementType::f32, ElementType::f32 },
};
std::vector<std::vector<ElementType>> matMulIn0PrecisionsQuant = {
{ ElementType::i8, ElementType::i8 },
{ ElementType::i8, ElementType::u8 },
};
std::vector<size_t> patternTypesQuant = {
0, 1
};
INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant, MHAQuantTest,
::testing::Combine(
::testing::ValuesIn(static_shapes_to_test_representation(inputShapesQuant)),
::testing::ValuesIn(inputPrecisionsQuant),
::testing::ValuesIn(matMulIn0PrecisionsQuant),
::testing::ValuesIn(patternTypesQuant),
::testing::Values(CommonTestUtils::DEVICE_CPU)),
MHAQuantTest::getTestCaseName);
} // namespace
} // namespace CPUSubgraphTestsDefinitions