[CPU] Support MHA optimization (#12936)
parent 0dd1f6e1b0
commit e7fe00f5f2
@@ -97,6 +97,13 @@ INFERENCE_ENGINE_API_CPP(bool) with_cpu_x86_avx512f();
  */
 INFERENCE_ENGINE_API_CPP(bool) with_cpu_x86_avx512_core();
 
+/**
+ * @brief Checks whether CPU supports AVX 512 VNNI capability
+ * @ingroup ie_dev_api_system_conf
+ * @return `True` if AVX512F, AVX512BW, AVX512DQ, AVX512_VNNI instructions are available, `false` otherwise
+ */
+INFERENCE_ENGINE_API_CPP(bool) with_cpu_x86_avx512_core_vnni();
+
 /**
  * @brief Checks whether CPU supports BFloat16 capability
  * @ingroup ie_dev_api_system_conf
@@ -41,6 +41,10 @@ bool with_cpu_x86_avx512_core() {
     return get_cpu_info().has(Xbyak::util::Cpu::tAVX512F | Xbyak::util::Cpu::tAVX512DQ | Xbyak::util::Cpu::tAVX512BW);
 }
 
+bool with_cpu_x86_avx512_core_vnni() {
+    return with_cpu_x86_avx512_core() && get_cpu_info().has(Xbyak::util::Cpu::tAVX512_VNNI);
+}
+
 bool with_cpu_x86_bfloat16() {
     return get_cpu_info().has(Xbyak::util::Cpu::tAVX512_BF16);
 }
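The new query complements the existing AVX512 checks. A minimal sketch of how a caller might gate a quantized MHA path on it (the wrapper function and include path are illustrative, not part of this diff; the queries themselves are the system-configuration API shown in the header hunk above):

#include <ie_system_conf.h>  // assumed public header exposing the queries

// Illustrative gating helper: AVX512-core is the baseline for the fused
// kernels, and VNNI provides the int8 dot-product instructions that the
// quantized path relies on.
bool can_use_int8_mha() {
    return InferenceEngine::with_cpu_x86_avx512_core() &&
           InferenceEngine::with_cpu_x86_avx512_core_vnni();
}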
@@ -197,6 +197,7 @@ const InferenceEngine::details::caseless_unordered_map<std::string, Type> type_t
         { "Subgraph", Type::Subgraph},
         { "PriorBox", Type::PriorBox},
         { "PriorBoxClustered", Type::PriorBoxClustered},
+        { "MHA", Type::MHA},
 };
 
 Type TypeFromName(const std::string& type) {
@@ -388,6 +389,8 @@ std::string NameFromType(const Type type) {
         return "Reference";
     case Type::Subgraph:
         return "Subgraph";
+    case Type::MHA:
+        return "MHA";
     default:
         return "Unknown";
     }
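With the map entry above and this switch case, the node type round-trips between string and enum: TypeFromName("MHA") resolves to Type::MHA and NameFromType(Type::MHA) yields "MHA", so the new node is identified consistently in type lookup and debug output.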
@@ -108,6 +108,7 @@ enum class Type {
     Subgraph,
     PriorBox,
     PriorBoxClustered,
+    MHA
 };
 
 enum class Algorithm {
@@ -7,6 +7,7 @@
 #include "ngraph_transformations/op/leaky_relu.hpp"
 #include "ngraph_transformations/op/power_static.hpp"
 #include "ngraph_transformations/op/swish_cpu.hpp"
+#include "ngraph_transformations/op/mha.hpp"
 #include "snippets_transformations/op/load_convert.hpp"
 #include "snippets_transformations/op/store_convert.hpp"
@@ -44,6 +45,7 @@ std::map<std::string, ngraph::OpSet> Extension::getOpSets() {
     NGRAPH_OP(LeakyReluNode, ov::intel_cpu)
     NGRAPH_OP(PowerStaticNode, ov::intel_cpu)
     NGRAPH_OP(SwishNode, ov::intel_cpu)
+    NGRAPH_OP(MHANode, ov::intel_cpu)
     NGRAPH_OP(LoadConvertSaturation, ov::intel_cpu)
     NGRAPH_OP(LoadConvertTruncation, ov::intel_cpu)
     NGRAPH_OP(StoreConvertSaturation, ov::intel_cpu)
src/plugins/intel_cpu/src/ngraph_transformations/mha_fusion.cpp (new file, 644 lines)
@@ -0,0 +1,644 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "mha_fusion.hpp"

#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include "op/mha.hpp"

#include "itt.hpp"

// TODO: draw pattern
ov::intel_cpu::MHAFloatFusion::MHAFloatFusion() {
    MATCHER_SCOPE(MHAFloatFusion);

    auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
    auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
    auto in2 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
    auto in4 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in5 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in6 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in7 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
    auto in9 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in10 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto transpose0 = std::make_shared<ngraph::opset3::Transpose>(in0, in4);
    auto transpose1 = std::make_shared<ngraph::opset3::Transpose>(in1, in5);
    auto mul = std::make_shared<ngraph::opset3::Multiply>(transpose1, in2);
    auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, mul);
    auto add = std::make_shared<ngraph::opset4::Add>(matmul0, in3);
    auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(add, in6, true);
    auto softmax = std::make_shared<ngraph::opset1::Softmax>(reshape0);
    auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softmax, in7, true);
    auto transpose2 = std::make_shared<ngraph::opset3::Transpose>(in8, in9);
    auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(reshape1, transpose2);
    auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(matmul1, in10);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
        auto& pattern_to_output = m.get_pattern_value_map();
        auto transpose0_in = pattern_to_output.at(in0);
        auto transpose1_in = pattern_to_output.at(in1);
        auto mul_in1 = pattern_to_output.at(in2);
        auto add_in1 = pattern_to_output.at(in3);
        auto transpose2_in = pattern_to_output.at(in8);

        if (transpose0_in.get_shape() != transpose1_in.get_shape() || transpose0_in.get_shape() != transpose2_in.get_shape()) {
            return false;
        }

        if (transpose0_in.get_shape().size() != 4) {
            return false;
        }

        auto expected_add_shape = Shape({transpose0_in.get_shape()[0], 1, 1, transpose0_in.get_shape()[1]});
        if (add_in1.get_shape() != expected_add_shape) {
            return false;
        }

        if (!valid_transpose_order(pattern_to_output.at(in4).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
        if (!valid_transpose_order(pattern_to_output.at(in5).get_node_shared_ptr(), {0, 2, 3, 1})) return false;
        if (!valid_transpose_order(pattern_to_output.at(in9).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
        if (!valid_transpose_order(pattern_to_output.at(in10).get_node_shared_ptr(), {0, 2, 1, 3})) return false;

        std::vector<float> mul_scales;
        if (auto mul_node = ngraph::as_type_ptr<ngraph::opset3::Multiply>(pattern_to_output.at(mul).get_node_shared_ptr())) {
            mul_scales = ngraph::as_type_ptr<ngraph::opset4::Constant>(mul_node->get_input_node_shared_ptr(1))->cast_vector<float>();

            auto expected_shape = ngraph::Shape({1, transpose0_in.get_shape()[2], 1, 1});
            if (mul_scales.size() != 1 && mul_node->get_input_shape(1) != expected_shape) {
                return false;
            }
        } else {
            return false;
        }

        auto matmul0_node = ngraph::as_type_ptr<ngraph::opset3::MatMul>(pattern_to_output.at(matmul0).get_node_shared_ptr());
        if (!matmul0_node)
            return false;
        if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b())
            return false;

        auto reshape0_node = ngraph::as_type_ptr<ngraph::opset1::Reshape>(pattern_to_output.at(reshape0).get_node_shared_ptr());
        if (!reshape0_node)
            return false;

        if (auto reshape_pattern = ngraph::as_type_ptr<ngraph::opset4::Constant>(pattern_to_output.at(in6).get_node_shared_ptr())) {
            if (reshape0_node->get_input_shape(0).size() != 4) {
                return false;
            }

            std::vector<int64_t> reshapeConstData = {static_cast<int64_t>(reshape0_node->get_input_shape(0)[0] *
                                                                          reshape0_node->get_input_shape(0)[1] *
                                                                          reshape0_node->get_input_shape(0)[2]),
                                                     -1};

            if (reshape_pattern->cast_vector<int64_t>() != reshapeConstData) {
                return false;
            }
        } else {
            return false;
        }

        if (auto reshape1_node = ngraph::as_type_ptr<ngraph::opset1::Reshape>(pattern_to_output.at(reshape1).get_node_shared_ptr())) {
            if (reshape0_node->get_input_shape(0) != reshape1_node->get_output_shape(0)) {
                return false;
            }
        } else {
            return false;
        }

        auto softmax_node = ngraph::as_type_ptr<ngraph::opset1::Softmax>(pattern_to_output.at(softmax).get_node_shared_ptr());
        if (!softmax_node)
            return false;
        if (softmax_node->get_axis() != 1)
            return false;

        auto matmul1_node = ngraph::as_type_ptr<ngraph::opset3::MatMul>(pattern_to_output.at(matmul1).get_node_shared_ptr());
        if (!matmul1_node)
            return false;
        if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b())
            return false;

        bool is_mul_first = true;
        auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr();
        auto mha = std::make_shared<ov::intel_cpu::MHANode>(transpose0_in, transpose1_in, add_in1, transpose2_in, mul_scales, is_mul_first,
                                                            transpose3_node->get_output_element_type(0));
        mha->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(),
                                   pattern_to_output.at(transpose1).get_node_shared_ptr(),
                                   pattern_to_output.at(mul).get_node_shared_ptr(),
                                   pattern_to_output.at(matmul0).get_node_shared_ptr(),
                                   pattern_to_output.at(add).get_node_shared_ptr(),
                                   pattern_to_output.at(reshape0).get_node_shared_ptr(),
                                   pattern_to_output.at(softmax).get_node_shared_ptr(),
                                   pattern_to_output.at(reshape1).get_node_shared_ptr(),
                                   pattern_to_output.at(transpose2).get_node_shared_ptr(),
                                   pattern_to_output.at(matmul1).get_node_shared_ptr(),
                                   pattern_to_output.at(transpose3).get_node_shared_ptr(),
                                  },
                                  mha);

        if (transformation_callback(mha)) {
            return false;
        }

        ngraph::replace_node(m.get_match_root(), mha);

        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(transpose3, matcher_name);
    this->register_matcher(m, callback);
}

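For orientation, the float pattern built above (the TODO in the source asks for exactly such a drawing; transpose orders are the ones enforced via valid_transpose_order) is, schematically:

// in0 -> Transpose{0,2,1,3} ------------------\
//                                              MatMul0 -> Add(in3) -> Reshape -> Softmax(axis=1) -> Reshape --\
// in1 -> Transpose{0,2,3,1} -> Multiply(in2) -/                                                                MatMul1 -> Transpose{0,2,1,3} -> out
// in8 -> Transpose{0,2,1,3} ------------------------------------------------------------------------------------/

i.e. the standard transpose(Q) x scaled transpose(K) -> softmax -> x transpose(V) attention block, with the batched softmax flattened and restored by the Reshape pair.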
ov::intel_cpu::MHAFloatFusion2::MHAFloatFusion2() {
    MATCHER_SCOPE(MHAFloatFusion2);

    auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
    auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
    auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
    auto in4 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in5 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in6 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in7 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
    auto in9 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in10 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto transpose0 = std::make_shared<ngraph::opset3::Transpose>(in0, in4);
    auto transpose1 = std::make_shared<ngraph::opset3::Transpose>(in1, in5);
    auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1);
    auto add = std::make_shared<ngraph::opset4::Add>(matmul0, in3);
    auto softmax = std::make_shared<ngraph::opset1::Softmax>(add);
    auto transpose2 = std::make_shared<ngraph::opset3::Transpose>(in8, in9);
    auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, transpose2);
    auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(matmul1, in10);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
        auto& pattern_to_output = m.get_pattern_value_map();
        auto transpose0_in = pattern_to_output.at(in0);
        auto transpose1_in = pattern_to_output.at(in1);
        auto add_in1 = pattern_to_output.at(in3);
        auto transpose2_in = pattern_to_output.at(in8);

        if (transpose0_in.get_shape() != transpose1_in.get_shape() || transpose0_in.get_shape() != transpose2_in.get_shape()) {
            return false;
        }

        if (transpose0_in.get_shape().size() != 4) {
            return false;
        }

        auto expected_add_shape = Shape({transpose0_in.get_shape()[0], 1, 1, transpose0_in.get_shape()[1]});
        if (add_in1.get_shape() != expected_add_shape) {
            return false;
        }

        if (!valid_transpose_order(pattern_to_output.at(in4).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
        if (!valid_transpose_order(pattern_to_output.at(in5).get_node_shared_ptr(), {0, 2, 3, 1})) return false;
        if (!valid_transpose_order(pattern_to_output.at(in9).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
        if (!valid_transpose_order(pattern_to_output.at(in10).get_node_shared_ptr(), {0, 2, 1, 3})) return false;

        auto matmul0_node = ngraph::as_type_ptr<ngraph::opset3::MatMul>(pattern_to_output.at(matmul0).get_node_shared_ptr());
        if (!matmul0_node)
            return false;
        if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b())
            return false;

        auto softmax_node = ngraph::as_type_ptr<ngraph::opset1::Softmax>(pattern_to_output.at(softmax).get_node_shared_ptr());
        if (!softmax_node)
            return false;
        if (softmax_node->get_axis() != 3)
            return false;

        auto matmul1_node = ngraph::as_type_ptr<ngraph::opset3::MatMul>(pattern_to_output.at(matmul1).get_node_shared_ptr());
        if (!matmul1_node)
            return false;
        if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b())
            return false;

        auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr();
        auto mha = std::make_shared<ov::intel_cpu::MHANode>(transpose0_in, transpose1_in, add_in1, transpose2_in, std::vector<float>(), false,
                                                            transpose3_node->get_output_element_type(0));
        mha->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(),
                                   pattern_to_output.at(transpose1).get_node_shared_ptr(),
                                   pattern_to_output.at(matmul0).get_node_shared_ptr(),
                                   pattern_to_output.at(add).get_node_shared_ptr(),
                                   pattern_to_output.at(softmax).get_node_shared_ptr(),
                                   pattern_to_output.at(transpose2).get_node_shared_ptr(),
                                   pattern_to_output.at(matmul1).get_node_shared_ptr(),
                                   pattern_to_output.at(transpose3).get_node_shared_ptr(),
                                  },
                                  mha);

        if (transformation_callback(mha)) {
            return false;
        }

        ngraph::replace_node(m.get_match_root(), mha);

        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(transpose3, matcher_name);
    this->register_matcher(m, callback);
}

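MHAFloatFusion2 covers the same attention block without the pre-softmax Multiply and without the Reshape pair, so Softmax runs directly on the 4D score tensor (axis = 3); accordingly the MHANode is created with empty mul_scales and is_mul_first = false.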
static std::vector<float> simplifyToScale(const std::shared_ptr<ngraph::opset1::FakeQuantize>& fq_node) {
    auto levels = fq_node->get_levels();
    auto input_low = ngraph::as_type_ptr<ngraph::opset4::Constant>(fq_node->get_input_node_shared_ptr(1))->cast_vector<float>();
    auto input_high = ngraph::as_type_ptr<ngraph::opset4::Constant>(fq_node->get_input_node_shared_ptr(2))->cast_vector<float>();
    auto output_low = ngraph::as_type_ptr<ngraph::opset4::Constant>(fq_node->get_input_node_shared_ptr(3))->cast_vector<float>();
    auto output_high = ngraph::as_type_ptr<ngraph::opset4::Constant>(fq_node->get_input_node_shared_ptr(4))->cast_vector<float>();

    std::vector<float> cl, ch, isc, ish, osc, osh;
    for (int i = 0; i < input_low.size(); i++) {
        cl.push_back(input_low[i]);
    }
    for (int i = 0; i < input_high.size(); i++) {
        ch.push_back(input_high[i]);
    }

    for (int i = 0; i < std::max(input_low.size(), input_high.size()); i++) {
        float il = input_low[input_low.size() == 1 ? 0 : i];
        float ih = input_high[input_high.size() == 1 ? 0 : i];

        isc.push_back((levels - 1) / (ih - il));
        ish.push_back(-il * (levels - 1) / (ih - il));
    }

    for (int i = 0; i < std::max(output_low.size(), output_high.size()); i++) {
        float ol = output_low[output_low.size() == 1 ? 0 : i];
        float oh = output_high[output_high.size() == 1 ? 0 : i];

        osc.push_back((oh - ol) / (levels - 1));
        osh.push_back(ol);
    }

    std::vector<float> outScale;

    if (fq_node->get_output_element_type(0) == ngraph::element::u8 &&
        std::all_of(cl.cbegin(), cl.cend(), [](float val) { return val == 0.0f; }) &&
        std::all_of(ish.cbegin(), ish.cend(), [](float val) { return val == 0.0f; }) &&
        std::all_of(osc.cbegin(), osc.cend(), [](float val) { return val == 1.0f; }) &&
        std::all_of(osh.cbegin(), osh.cend(), [](float val) { return val == 0.0f; })) {
        outScale = isc;
    }

    if (fq_node->get_output_element_type(0) == ngraph::element::i8 &&
        std::all_of(ish.cbegin(), ish.cend(), [](float val) { return std::abs(val - 128.f) < 0.0001f; }) &&
        std::all_of(osc.cbegin(), osc.cend(), [](float val) { return val == 1.f; }) &&
        std::all_of(osh.cbegin(), osh.cend(), [](float val) { return std::abs(val + 128.f) < 0.0001f; })) {
        bool isCropAligned = true;
        for (int i = 0; i < std::max(cl.size(), isc.size()); i++) {
            if (std::abs(cl[cl.size() == 1 ? 0 : i] * isc[isc.size() == 1 ? 0 : i] + 128.f) > 0.0001f) {
                isCropAligned = false;
            }
        }

        for (int i = 0; i < std::max(ch.size(), isc.size()); i++) {
            if (std::abs(ch[ch.size() == 1 ? 0 : i] * isc[isc.size() == 1 ? 0 : i] - 127.f) > 0.0001f) {
                isCropAligned = false;
            }
        }

        if (isCropAligned) {
            outScale = isc;
        }
    }

    return outScale;
}

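A worked example of what this helper accepts (values chosen for illustration): an i8 FakeQuantize with levels = 256, input range [-1, 0.9921875] and output range [-128, 127] gives isc = 255 / 1.9921875 = 128, ish = 255 / 1.9921875 = 128, osc = 1 and osh = -128; the crop-alignment checks pass (-1 * 128 = -128 and 0.9921875 * 128 = 127), so the whole FakeQuantize collapses to a single multiplication by 128, which is the scale vector returned to the fusion pass.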
// TODO: draw pattern
|
||||||
|
ov::intel_cpu::MHAQuantFusion::MHAQuantFusion() {
|
||||||
|
MATCHER_SCOPE(MHAQuantFusion);
|
||||||
|
|
||||||
|
auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||||
|
auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||||
|
auto in2 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||||
|
auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||||
|
auto in4 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||||
|
auto in5 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||||
|
auto in6 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||||
|
auto in7 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||||
|
auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
|
||||||
|
auto in9 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||||
|
auto in10 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
|
||||||
|
auto transpose0 = std::make_shared<ngraph::opset3::Transpose>(in0, in4);
|
||||||
|
auto transpose1 = std::make_shared<ngraph::opset3::Transpose>(in1, in5);
|
||||||
|
auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1);
|
||||||
|
auto fakeQuantize0 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({matmul0,
|
||||||
|
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||||
|
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||||
|
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||||
|
ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
|
||||||
|
auto add = std::make_shared<ngraph::opset4::Add>(fakeQuantize0, in3);
|
||||||
|
auto mul = std::make_shared<ngraph::opset3::Multiply>(add, in2);
|
||||||
|
auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(mul, in6, true);
|
||||||
|
auto softmax = std::make_shared<ngraph::opset1::Softmax>(reshape0);
|
||||||
|
auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softmax, in7, true);
|
||||||
|
auto fakeQuantize1 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({reshape1,
|
||||||
|
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||||
|
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||||
|
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||||
|
ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
|
||||||
|
auto transpose2 = std::make_shared<ngraph::opset3::Transpose>(in8, in9);
|
||||||
|
auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(fakeQuantize1, transpose2);
|
||||||
|
auto fakeQuantize2 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({matmul1,
|
||||||
|
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||||
|
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||||
|
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||||
|
ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
|
||||||
|
auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(fakeQuantize2, in10);
|
||||||
|
|
||||||
|
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||||
|
auto& pattern_to_output = m.get_pattern_value_map();
|
||||||
|
auto transpose0_in = pattern_to_output.at(in0);
|
||||||
|
auto transpose1_in = pattern_to_output.at(in1);
|
||||||
|
auto add_in1 = pattern_to_output.at(in3);
|
||||||
|
auto transpose2_in = pattern_to_output.at(in8);
|
||||||
|
|
||||||
|
if (transpose0_in.get_shape() != transpose1_in.get_shape() || transpose0_in.get_shape() != transpose2_in.get_shape()) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (transpose0_in.get_shape().size() != 4) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto expected_add_shape = Shape({transpose0_in.get_shape()[0], 1, 1, transpose0_in.get_shape()[1]});
|
||||||
|
if (add_in1.get_shape() != expected_add_shape) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<float> mul_scales;
|
||||||
|
if (auto mul_node = ngraph::as_type_ptr<ngraph::opset3::Multiply>(pattern_to_output.at(mul).get_node_shared_ptr())) {
|
||||||
|
mul_scales = ngraph::as_type_ptr<ngraph::opset4::Constant>(mul_node->get_input_node_shared_ptr(1))->cast_vector<float>();
|
||||||
|
|
||||||
|
auto expected_shape = ngraph::Shape({1, transpose0_in.get_shape()[2], 1, 1});
|
||||||
|
if (mul_scales.size() != 1 && mul_node->get_input_shape(1) != expected_shape) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!valid_transpose_order(pattern_to_output.at(in4).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
|
||||||
|
if (!valid_transpose_order(pattern_to_output.at(in5).get_node_shared_ptr(), {0, 2, 3, 1})) return false;
|
||||||
|
if (!valid_transpose_order(pattern_to_output.at(in9).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
|
||||||
|
if (!valid_transpose_order(pattern_to_output.at(in10).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
|
||||||
|
|
||||||
|
auto matmul0_node = ngraph::as_type_ptr<ngraph::opset3::MatMul>(pattern_to_output.at(matmul0).get_node_shared_ptr());
|
||||||
|
if (!matmul0_node)
|
||||||
|
return false;
|
||||||
|
if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
std::vector<float> fq0_scale;
|
||||||
|
auto fq0_node = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize0).get_node_shared_ptr());
|
||||||
|
if (fq0_node) {
|
||||||
|
fq0_scale = simplifyToScale(fq0_node);
|
||||||
|
if (!fq0_scale.size())
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto reshape0_node = ngraph::as_type_ptr<ngraph::opset1::Reshape>(pattern_to_output.at(reshape0).get_node_shared_ptr());
|
||||||
|
if (!reshape0_node)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
if (auto reshape_pattern = ngraph::as_type_ptr<ngraph::opset4::Constant>(pattern_to_output.at(in6).get_node_shared_ptr())) {
|
||||||
|
if (reshape0_node->get_input_shape(0).size() != 4) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
std::vector<int64_t> reshapeConstData = {static_cast<int64_t>(reshape0_node->get_input_shape(0)[0] *
|
||||||
|
reshape0_node->get_input_shape(0)[1] *
|
||||||
|
reshape0_node->get_input_shape(0)[2]),
|
||||||
|
-1};
|
||||||
|
|
||||||
|
if (reshape_pattern->cast_vector<int64_t>() != reshapeConstData) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (auto reshape1_node = ngraph::as_type_ptr<ngraph::opset1::Reshape>(pattern_to_output.at(reshape1).get_node_shared_ptr())) {
|
||||||
|
if (reshape0_node->get_input_shape(0) != reshape1_node->get_output_shape(0)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto softmax_node = ngraph::as_type_ptr<ngraph::opset1::Softmax>(pattern_to_output.at(softmax).get_node_shared_ptr());
|
||||||
|
if (!softmax_node)
|
||||||
|
return false;
|
||||||
|
if (softmax_node->get_axis() != 1)
|
||||||
|
return false;
|
||||||
|
|
||||||
|
std::vector<float> fq1_scale;
|
||||||
|
auto fq1_node = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize1).get_node_shared_ptr());
|
||||||
|
if (fq1_node) {
|
||||||
|
fq1_scale = simplifyToScale(fq1_node);
|
||||||
|
if (!fq1_scale.size())
|
||||||
|
return false;
|
||||||
|
} else {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
auto matmul1_node = ngraph::as_type_ptr<ngraph::opset3::MatMul>(pattern_to_output.at(matmul1).get_node_shared_ptr());
|
||||||
|
if (!matmul1_node)
|
||||||
|
return false;
|
||||||
|
if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b())
|
||||||
|
return false;
|
||||||
|
|
||||||
|
std::vector<float> fq2_scale;
|
||||||
|
if (auto fq_node = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize2).get_node_shared_ptr())) {
|
||||||
|
fq2_scale = simplifyToScale(fq_node);
|
||||||
|
if (!fq2_scale.size())
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
bool is_mul_first = false;
|
||||||
|
auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr();
|
||||||
|
auto mha = std::make_shared<ov::intel_cpu::MHANode>(transpose0_in, transpose1_in, add_in1, transpose2_in, mul_scales, is_mul_first,
|
||||||
|
std::vector<float>(), fq0_scale, fq1_scale, fq2_scale,
|
||||||
|
ngraph::element::undefined,
|
||||||
|
fq0_node ? fq0_node->get_output_element_type(0) : ngraph::element::undefined,
|
||||||
|
fq1_node->get_output_element_type(0), transpose3_node->get_output_element_type(0));
|
||||||
|
mha->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||||
|
ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(),
|
||||||
|
pattern_to_output.at(transpose1).get_node_shared_ptr(),
|
||||||
|
pattern_to_output.at(matmul0).get_node_shared_ptr(),
|
||||||
|
pattern_to_output.at(fakeQuantize0).get_node_shared_ptr(),
|
||||||
|
pattern_to_output.at(add).get_node_shared_ptr(),
|
||||||
|
pattern_to_output.at(mul).get_node_shared_ptr(),
|
||||||
|
pattern_to_output.at(reshape0).get_node_shared_ptr(),
|
||||||
|
pattern_to_output.at(softmax).get_node_shared_ptr(),
|
||||||
|
pattern_to_output.at(reshape1).get_node_shared_ptr(),
|
||||||
|
pattern_to_output.at(fakeQuantize1).get_node_shared_ptr(),
|
||||||
|
pattern_to_output.at(transpose2).get_node_shared_ptr(),
|
||||||
|
pattern_to_output.at(matmul1).get_node_shared_ptr(),
|
||||||
|
pattern_to_output.at(fakeQuantize2).get_node_shared_ptr(),
|
||||||
|
pattern_to_output.at(transpose3).get_node_shared_ptr(),
|
||||||
|
},
|
||||||
|
mha);
|
||||||
|
|
||||||
|
if (transformation_callback(mha)) {
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
ngraph::replace_node(m.get_match_root(), mha);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<ngraph::pattern::Matcher>(transpose3, matcher_name);
|
||||||
|
this->register_matcher(m, callback);
|
||||||
|
}
|
||||||
|
|
||||||
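Schematically, the quantized pattern matched above is:

// in0 -> Transpose{0,2,1,3} --\
//                              MatMul0 -> FQ0 -> Add(in3) -> Multiply(in2) -> Reshape -> Softmax(axis=1) -> Reshape -> FQ1 --\
// in1 -> Transpose{0,2,3,1} --/                                                                                               MatMul1 -> FQ2 -> Transpose{0,2,1,3} -> out
// in8 -> Transpose{0,2,1,3} -----------------------------------------------------------------------------------------------------/

Each FakeQuantize that simplifyToScale can collapse contributes a scalar or per-channel scale vector to the fused MHANode instead of surviving as a separate op.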
// TODO: draw pattern
ov::intel_cpu::MHAQuantFusion2::MHAQuantFusion2() {
    MATCHER_SCOPE(MHAQuantFusion2);

    auto in0 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
    auto in1 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
    auto in2 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in3 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
    auto in4 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in5 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in8 = ngraph::pattern::any_input(ngraph::pattern::has_static_shape());
    auto in9 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto in10 = ngraph::pattern::wrap_type<ngraph::opset4::Constant>();
    auto transpose0 = std::make_shared<ngraph::opset3::Transpose>(in0, in4);
    auto transpose1 = std::make_shared<ngraph::opset3::Transpose>(in1, in5);
    auto fakeQuantize0 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({transpose1,
                                                                                   ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
                                                                                   ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
                                                                                   ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
                                                                                   ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
    auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, fakeQuantize0);
    auto mul = std::make_shared<ngraph::opset3::Multiply>(matmul0, in2);
    auto add = std::make_shared<ngraph::opset4::Add>(mul, in3);
    auto softmax = std::make_shared<ngraph::opset1::Softmax>(add);
    auto transpose2 = std::make_shared<ngraph::opset3::Transpose>(in8, in9);
    auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(softmax, transpose2);
    auto fakeQuantize1 = ngraph::pattern::wrap_type<ngraph::opset1::FakeQuantize>({matmul1,
                                                                                   ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
                                                                                   ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
                                                                                   ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
                                                                                   ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
    auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(fakeQuantize1, in10);

    ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
        auto& pattern_to_output = m.get_pattern_value_map();
        auto transpose0_in = pattern_to_output.at(in0);
        auto transpose1_in = pattern_to_output.at(in1);
        auto add_in1 = pattern_to_output.at(in3);
        auto transpose2_in = pattern_to_output.at(in8);

        if (transpose0_in.get_shape() != transpose1_in.get_shape() || transpose0_in.get_shape() != transpose2_in.get_shape()) {
            return false;
        }

        if (transpose0_in.get_shape().size() != 4) {
            return false;
        }

        auto expected_add_shape = Shape({transpose0_in.get_shape()[0], 1, 1, transpose0_in.get_shape()[1]});
        if (add_in1.get_shape() != expected_add_shape) {
            return false;
        }

        std::vector<float> mul_scales;
        if (auto mul_node = ngraph::as_type_ptr<ngraph::opset3::Multiply>(pattern_to_output.at(mul).get_node_shared_ptr())) {
            mul_scales = ngraph::as_type_ptr<ngraph::opset4::Constant>(mul_node->get_input_node_shared_ptr(1))->cast_vector<float>();

            auto expected_shape = ngraph::Shape({1, transpose0_in.get_shape()[2], 1, 1});
            if (mul_scales.size() != 1 && mul_node->get_input_shape(1) != expected_shape) {
                return false;
            }
        } else {
            return false;
        }

        if (!valid_transpose_order(pattern_to_output.at(in4).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
        if (!valid_transpose_order(pattern_to_output.at(in5).get_node_shared_ptr(), {0, 2, 3, 1})) return false;
        if (!valid_transpose_order(pattern_to_output.at(in9).get_node_shared_ptr(), {0, 2, 1, 3})) return false;
        if (!valid_transpose_order(pattern_to_output.at(in10).get_node_shared_ptr(), {0, 2, 1, 3})) return false;

        auto matmul0_node = ngraph::as_type_ptr<ngraph::opset3::MatMul>(pattern_to_output.at(matmul0).get_node_shared_ptr());
        if (!matmul0_node)
            return false;
        if (matmul0_node->get_transpose_a() || matmul0_node->get_transpose_b())
            return false;

        std::vector<float> fq0_scale;
        auto fq0_node = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize0).get_node_shared_ptr());
        if (fq0_node) {
            fq0_scale = simplifyToScale(fq0_node);
            if (!fq0_scale.size())
                return false;
        } else {
            return false;
        }

        auto softmax_node = ngraph::as_type_ptr<ngraph::opset1::Softmax>(pattern_to_output.at(softmax).get_node_shared_ptr());
        if (!softmax_node)
            return false;
        if (softmax_node->get_axis() != 3)
            return false;

        std::vector<float> fq1_scale;
        if (auto fq_node = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize1).get_node_shared_ptr())) {
            fq1_scale = simplifyToScale(fq_node);
            if (!fq1_scale.size())
                return false;
        }

        auto matmul1_node = ngraph::as_type_ptr<ngraph::opset3::MatMul>(pattern_to_output.at(matmul1).get_node_shared_ptr());
        if (!matmul1_node)
            return false;
        if (matmul1_node->get_transpose_a() || matmul1_node->get_transpose_b())
            return false;

        bool is_mul_first = true;
        auto transpose3_node = pattern_to_output.at(transpose3).get_node_shared_ptr();
        auto mha = std::make_shared<ov::intel_cpu::MHANode>(transpose0_in, transpose1_in, add_in1, transpose2_in, mul_scales, is_mul_first,
                                                            fq0_scale, std::vector<float>(), std::vector<float>(), fq1_scale,
                                                            fq0_node->get_output_element_type(0), ngraph::element::undefined, ngraph::element::undefined,
                                                            transpose3_node->get_output_element_type(0));
        mha->set_friendly_name(m.get_match_root()->get_friendly_name());
        ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(),
                                   pattern_to_output.at(transpose1).get_node_shared_ptr(),
                                   pattern_to_output.at(fakeQuantize0).get_node_shared_ptr(),
                                   pattern_to_output.at(matmul0).get_node_shared_ptr(),
                                   pattern_to_output.at(mul).get_node_shared_ptr(),
                                   pattern_to_output.at(add).get_node_shared_ptr(),
                                   pattern_to_output.at(softmax).get_node_shared_ptr(),
                                   pattern_to_output.at(transpose2).get_node_shared_ptr(),
                                   pattern_to_output.at(matmul1).get_node_shared_ptr(),
                                   pattern_to_output.at(fakeQuantize1).get_node_shared_ptr(),
                                   pattern_to_output.at(transpose3).get_node_shared_ptr(),
                                  },
                                  mha);

        if (transformation_callback(mha)) {
            return false;
        }

        ngraph::replace_node(m.get_match_root(), mha);

        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(transpose3, matcher_name);
    this->register_matcher(m, callback);
}
src/plugins/intel_cpu/src/ngraph_transformations/mha_fusion.hpp (new file, 64 lines)
@@ -0,0 +1,64 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <ngraph/pass/graph_rewrite.hpp>
#include <ngraph/opsets/opset4.hpp>

namespace ov {
namespace intel_cpu {

class MHAFusionBase : public ngraph::pass::MatcherPass {
protected:
    bool valid_transpose_order(const std::shared_ptr<ngraph::Node>& node, const std::vector<int64_t>& expected_order) {
        if (auto transpose_pattern = ngraph::as_type_ptr<ngraph::opset4::Constant>(node)) {
            if (transpose_pattern->cast_vector<int64_t>() != expected_order) {
                return false;
            }
        } else {
            return false;
        }

        return true;
    }
};

class MHAFloatFusion: public MHAFusionBase {
public:
    OPENVINO_RTTI("MHAFloatFusion", "0");
    MHAFloatFusion();
};

class MHAFloatFusion2: public MHAFusionBase {
public:
    OPENVINO_RTTI("MHAFloatFusion2", "0");
    MHAFloatFusion2();
};

class MHAQuantFusion: public MHAFusionBase {
public:
    OPENVINO_RTTI("MHAQuantFusion", "0");
    MHAQuantFusion();
};

class MHAQuantFusion2: public MHAFusionBase {
public:
    OPENVINO_RTTI("MHAQuantFusion2", "0");
    MHAQuantFusion2();
};

class MHAFusion : public ngraph::pass::GraphRewrite {
public:
    OPENVINO_RTTI("MHAFusion", "0");
    MHAFusion() {
        add_matcher<MHAFloatFusion>();
        add_matcher<MHAFloatFusion2>();
        add_matcher<MHAQuantFusion>();
        add_matcher<MHAQuantFusion2>();
    }
};

} // namespace intel_cpu
} // namespace ov
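A minimal sketch of running the aggregate pass on a model (the free function is illustrative; in the plugin, MHAFusion is wired into the CPU transformation pipeline):

#include <ngraph/pass/manager.hpp>
#include "ngraph_transformations/mha_fusion.hpp"

void run_mha_fusion(const std::shared_ptr<ngraph::Function>& model) {
    ngraph::pass::Manager manager;
    // One GraphRewrite registration pulls in all four matchers
    // (two float and two quantized variants).
    manager.register_pass<ov::intel_cpu::MHAFusion>();
    manager.run_passes(model);
}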
src/plugins/intel_cpu/src/ngraph_transformations/op/mha.cpp (new file, 108 lines)
@@ -0,0 +1,108 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "mha.hpp"
#include "../itt.hpp"
#include <ngraph/opsets/opset3.hpp>
#include <matmul_shape_inference.hpp>

ov::intel_cpu::MHANode::MHANode(const ngraph::Output<ngraph::Node> &in0,
                                const ngraph::Output<ngraph::Node> &in1,
                                const ngraph::Output<ngraph::Node> &in2,
                                const ngraph::Output<ngraph::Node> &in3,
                                const std::vector<float> &mul_scales,
                                bool is_mul_first,
                                const ngraph::element::Type output_type)
    : Op({in0, in1, in2, in3}), m_output_type(output_type) {
    this->mul_scales = mul_scales;
    this->is_mul_first = is_mul_first;
    this->fq0_output_type = ngraph::element::undefined;
    this->fq1_output_type = ngraph::element::undefined;
    this->fq2_output_type = ngraph::element::undefined;
    validate_and_infer_types();
}

ov::intel_cpu::MHANode::MHANode(const ngraph::Output<ngraph::Node> &in0,
                                const ngraph::Output<ngraph::Node> &in1,
                                const ngraph::Output<ngraph::Node> &in2,
                                const ngraph::Output<ngraph::Node> &in3,
                                const std::vector<float> &mul_scales,
                                bool is_mul_first,
                                const std::vector<float> &fq_scales0,
                                const std::vector<float> &fq_scales1,
                                const std::vector<float> &fq_scales2,
                                const std::vector<float> &fq_scales3,
                                const ngraph::element::Type fq0_output_type,
                                const ngraph::element::Type fq1_output_type,
                                const ngraph::element::Type fq2_output_type,
                                const ngraph::element::Type output_type)
    : Op({in0, in1, in2, in3}), m_output_type(output_type) {
    this->mul_scales = mul_scales;
    this->is_mul_first = is_mul_first;
    this->fq_scales0 = fq_scales0;
    this->fq_scales1 = fq_scales1;
    this->fq_scales2 = fq_scales2;
    this->fq_scales3 = fq_scales3;
    this->fq0_output_type = fq0_output_type;
    this->fq1_output_type = fq1_output_type;
    this->fq2_output_type = fq2_output_type;
    validate_and_infer_types();
}

std::shared_ptr<ngraph::Node> ov::intel_cpu::MHANode::clone_with_new_inputs(const ngraph::OutputVector& new_args) const {
    INTERNAL_OP_SCOPE(MHANode_clone_with_new_inputs);
    check_new_args_count(this, new_args);
    return std::make_shared<ov::intel_cpu::MHANode>(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3),
                                                    mul_scales, is_mul_first, fq_scales0, fq_scales1, fq_scales2, fq_scales3,
                                                    fq0_output_type, fq1_output_type, fq2_output_type, m_output_type);
}

void ov::intel_cpu::MHANode::validate_and_infer_types() {
    INTERNAL_OP_SCOPE(MHANode_validate_and_infer_types);

    auto transpose = [](const ov::Shape& shape, const std::vector<size_t>& order) -> ov::Shape {
        std::vector<size_t> new_shape(shape.size());
        for (int i = 0; i < shape.size(); i++) {
            new_shape[i] = shape[order[i]];
        }
        return new_shape;
    };

    const auto matmul0_shape0 = transpose(get_input_partial_shape(0).get_shape(), {0, 2, 1, 3});
    const auto matmul0_shape1 = transpose(get_input_partial_shape(1).get_shape(), {0, 2, 3, 1});

    auto matmul0_in0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, matmul0_shape0);
    auto matmul0_in1 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, matmul0_shape1);
    auto matmul0 = std::make_shared<ngraph::opset3::MatMul>(matmul0_in0, matmul0_in1);

    std::vector<ov::PartialShape> matmul0_input_shapes = {matmul0_shape0, matmul0_shape1};
    std::vector<ov::PartialShape> matmul0_output_shapes = {ov::PartialShape{}};

    shape_infer(matmul0.get(), matmul0_input_shapes, matmul0_output_shapes);

    const auto matmul1_shape0 = matmul0_output_shapes[0];
    const auto matmul1_shape1 = transpose(get_input_partial_shape(3).get_shape(), {0, 2, 1, 3});

    auto matmul1_in0 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, matmul1_shape0);
    auto matmul1_in1 = std::make_shared<ngraph::opset3::Parameter>(ngraph::element::f32, matmul1_shape1);
    auto matmul1 = std::make_shared<ngraph::opset3::MatMul>(matmul1_in0, matmul1_in1);

    std::vector<ov::PartialShape> matmul1_input_shapes = {matmul1_shape0, matmul1_shape1};
    std::vector<ov::PartialShape> matmul1_output_shapes = {ov::PartialShape{}};

    shape_infer(matmul1.get(), matmul1_input_shapes, matmul1_output_shapes);

    const auto output_shape = transpose(matmul1_output_shapes[0].get_shape(), {0, 2, 1, 3});

    set_output_type(
        0,
        m_output_type == ngraph::element::undefined || m_output_type == ngraph::element::dynamic ? get_input_element_type(0) : m_output_type,
        output_shape);
}

bool ov::intel_cpu::MHANode::visit_attributes(ngraph::AttributeVisitor &visitor) {
    INTERNAL_OP_SCOPE(MHANode_visit_attributes);
    visitor.on_attribute("out-type", m_output_type);
    return true;
}
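To make the shape plumbing concrete (an illustrative example, not from the diff): for inputs laid out as [batch, seq_len, heads, head_size] = [1, 128, 12, 64], the {0,2,1,3} transpose of in0 gives [1, 12, 128, 64] and the {0,2,3,1} transpose of in1 gives [1, 12, 64, 128]; matmul0 then produces the attention matrix [1, 12, 128, 128], matmul1 against the transposed in3 ([1, 12, 128, 64]) produces [1, 12, 128, 64], and the final {0,2,1,3} transpose restores [1, 128, 12, 64].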
src/plugins/intel_cpu/src/ngraph_transformations/op/mha.hpp (new file, 94 lines)
@@ -0,0 +1,94 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <ngraph/op/op.hpp>

namespace ov {
namespace intel_cpu {

class MHANode : public ngraph::op::Op {
public:
    OPENVINO_OP("MHA", "cpu_plugin_opset");

    MHANode() = default;

    MHANode(const ngraph::Output<ngraph::Node> &in0,
            const ngraph::Output<ngraph::Node> &in1,
            const ngraph::Output<ngraph::Node> &in2,
            const ngraph::Output<ngraph::Node> &in3,
            const std::vector<float> &mul_scales,
            bool is_mul_first,
            const ngraph::element::Type output_type);

    MHANode(const ngraph::Output<ngraph::Node> &in0,
            const ngraph::Output<ngraph::Node> &in1,
            const ngraph::Output<ngraph::Node> &in2,
            const ngraph::Output<ngraph::Node> &in3,
            const std::vector<float> &mul_scales,
            bool is_mul_first,
            const std::vector<float> &fq_scales0,
            const std::vector<float> &fq_scales1,
            const std::vector<float> &fq_scales2,
            const std::vector<float> &fq_scales3,
            const ngraph::element::Type fq0_output_type,
            const ngraph::element::Type fq1_output_type,
            const ngraph::element::Type fq2_output_type,
            const ngraph::element::Type output_type);

    void validate_and_infer_types() override;

    bool visit_attributes(ngraph::AttributeVisitor &visitor) override;

    std::shared_ptr<ngraph::Node> clone_with_new_inputs(const ngraph::OutputVector &new_args) const override;

    ngraph::element::Type get_output_type() const { return m_output_type; }

    const std::vector<float>& get_mul_scales() const {
        return mul_scales;
    }

    const std::vector<float>& get_fq_scales0() const {
        return fq_scales0;
    }
    const std::vector<float>& get_fq_scales1() const {
        return fq_scales1;
    }
    const std::vector<float>& get_fq_scales2() const {
        return fq_scales2;
    }
    const std::vector<float>& get_fq_scales3() const {
        return fq_scales3;
    }

    bool get_is_mul_first() const {
        return is_mul_first;
    }

    ngraph::element::Type get_fq0_output_type() const {
        return fq0_output_type;
    }
    ngraph::element::Type get_fq1_output_type() const {
        return fq1_output_type;
    }
    ngraph::element::Type get_fq2_output_type() const {
        return fq2_output_type;
    }

private:
    ngraph::element::Type m_output_type;
    std::vector<float> mul_scales;
    bool is_mul_first;
    std::vector<float> fq_scales0;
    std::vector<float> fq_scales1;
    std::vector<float> fq_scales2;
    std::vector<float> fq_scales3;
    ngraph::element::Type fq0_output_type;
    ngraph::element::Type fq1_output_type;
    ngraph::element::Type fq2_output_type;
};

} // namespace intel_cpu
} // namespace ov
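For illustration, the float-path constructor can be invoked directly (tensor names and the scale value here are hypothetical; in practice the fusion passes above create the node):

// q, k, mask, v: ngraph::Output<ngraph::Node> values with static 4D shapes
// as required by the matchers, e.g. [batch, seq_len, heads, head_size].
auto mha = std::make_shared<ov::intel_cpu::MHANode>(
    q, k, mask, v,
    std::vector<float>{0.125f},  // hypothetical 1/sqrt(head_size) scale
    /*is_mul_first=*/true,
    ngraph::element::f32);       // requested output precision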
src/plugins/intel_cpu/src/nodes/mha.cpp (new file, 1404 lines)
File diff suppressed because it is too large
243
src/plugins/intel_cpu/src/nodes/mha.h
Normal file
243
src/plugins/intel_cpu/src/nodes/mha.h
Normal file
@ -0,0 +1,243 @@
|
|||||||
|
// Copyright (C) 2018-2022 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <node.h>
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <string>
|
||||||
|
#include <vector>
|
||||||
|
#include <cpu/x64/brgemm/brgemm.hpp>
|
||||||
|
#include <cpu/x64/matmul/brgemm_matmul_copy_utils.hpp>
|
||||||
|
#include <cpu/x64/matmul/brgemm_matmul_utils.hpp>
|
||||||
|
#include <cpu/x64/amx_tile_configure.hpp>
|
||||||
|
|
||||||
|
namespace ov {
|
||||||
|
namespace intel_cpu {
|
||||||
|
namespace node {
|
||||||
|
|
||||||
|
struct jit_mul_add_softmax_compile_params {
|
||||||
|
InferenceEngine::Precision src_prc;
|
||||||
|
InferenceEngine::Precision dst_prc;
|
||||||
|
size_t work_amount;
|
||||||
|
bool with_mul_scales;
|
||||||
|
bool is_mul_first;
|
||||||
|
bool with_scales0;
|
||||||
|
bool broadcast_scales0;
|
||||||
|
bool with_scales1;
|
||||||
|
bool broadcast_scales1;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct jit_mul_add_softmax_call_args {
|
||||||
|
const void *p_in0;
|
||||||
|
const void *p_mul_in1;
|
||||||
|
const void *p_add_in1;
|
||||||
|
void *p_out;
|
||||||
|
void *p_buffer;
|
||||||
|
const void *p_scales0;
|
||||||
|
const void *p_scales1;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct jit_uni_mul_add_softmax_kernel {
|
||||||
|
void (*ker_)(const jit_mul_add_softmax_call_args*);
|
||||||
|
|
||||||
|
void operator()(const jit_mul_add_softmax_call_args* call_args) {
|
||||||
|
assert(ker_);
|
||||||
|
ker_(call_args);
|
||||||
|
}
|
||||||
|
|
||||||
|
explicit jit_uni_mul_add_softmax_kernel(const jit_mul_add_softmax_compile_params& jcp) : ker_(nullptr), jcp_(jcp) {}
|
||||||
|
virtual ~jit_uni_mul_add_softmax_kernel() {}
|
||||||
|
|
||||||
|
virtual void create_ker() = 0;
|
||||||
|
|
||||||
|
jit_mul_add_softmax_compile_params jcp_;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct jit_convert_reorder_compile_params {
|
||||||
|
InferenceEngine::Precision src_prc;
|
||||||
|
InferenceEngine::Precision dst_prc;
|
||||||
|
size_t inner_work_amount;
|
||||||
|
bool with_scales;
|
||||||
|
bool broadcast_scales;
|
||||||
|
size_t src_stride;
|
||||||
|
size_t dst_stride;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct jit_convert_reorder_call_args {
|
||||||
|
const void *p_in;
|
||||||
|
void *p_out;
|
||||||
|
const void *p_scales;
|
||||||
|
size_t outter_work_amount;
|
||||||
|
};
|
||||||
|
|
||||||
|
struct jit_uni_convert_reorder_kernel {
|
||||||
|
void (*ker_)(const jit_convert_reorder_call_args*);
|
||||||
|
|
||||||
|
void operator()(const jit_convert_reorder_call_args* call_args) {
|
||||||
|
        assert(ker_);
        ker_(call_args);
    }

    explicit jit_uni_convert_reorder_kernel(const jit_convert_reorder_compile_params& jcp) : ker_(nullptr), jcp_(jcp) {}
    virtual ~jit_uni_convert_reorder_kernel() {}

    virtual void create_ker() = 0;

    jit_convert_reorder_compile_params jcp_;
};

struct jit_convert_transpose_compile_params {
    InferenceEngine::Precision src_prc;
    InferenceEngine::Precision dst_prc;
    size_t inner_work_amount;
    size_t outter_work_amount;
    bool with_scales;
    bool broadcast_scales;
    size_t inner_src_stride;
    size_t outter_src_stride;
    size_t outter_dst_stride;
};

struct jit_convert_transpose_call_args {
    const void *p_in;
    void *p_out;
    const void *p_scales;
};

struct jit_uni_convert_transpose_kernel {
    void (*ker_)(const jit_convert_transpose_call_args*);

    void operator()(const jit_convert_transpose_call_args* call_args) {
        assert(ker_);
        ker_(call_args);
    }

    explicit jit_uni_convert_transpose_kernel(const jit_convert_transpose_compile_params& jcp) : ker_(nullptr), jcp_(jcp) {}
    virtual ~jit_uni_convert_transpose_kernel() {}

    virtual void create_ker() = 0;

    jit_convert_transpose_compile_params jcp_;
};

#define MHA_BRGEMM_KERNELS_NUM 8

class MHA : public Node {
public:
    MHA(const std::shared_ptr<ngraph::Node>& op, const dnnl::engine& eng, WeightsSharing::Ptr &cache);

    void getSupportedDescriptors() override {};
    void initSupportedPrimitiveDescriptors() override;
    void execute(dnnl::stream strm) override;
    bool created() const override;

    static bool isSupportedOperation(const std::shared_ptr<const ngraph::Node>& op, std::string& errorMessage) noexcept;

protected:
    void executeDynamicImpl(dnnl::stream strm) override;
    void prepareParams() override;

private:
    struct brgemmCtx {
        size_t M, N, K, LDA, LDB, LDC;
        dnnl_data_type_t dt_in0, dt_in1;
        char palette[64];
        bool is_with_amx;
        bool is_with_comp;
        float beta;
    };

    template <typename in1_type>
    void mhaImpl();

    void init_brgemm(brgemmCtx& ctx, std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& brgKernel, bool use_amx);
    void init_brgemm_copy_a(std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_a_t>& brgCopyKernel,
                            size_t K, size_t K_blk, size_t K_tail, size_t LDA, dnnl_data_type_t dt_in0);
    void init_brgemm_copy_b(std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t>& brgCopyKernel,
                            size_t N, size_t N_blk, size_t N_tail, size_t LDB, size_t K, bool is_with_amx, dnnl_data_type_t dt_in0, dnnl_data_type_t dt_in1);

    void callBrgemm(brgemmCtx& ctx, std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t>& brgKernel,
                    const void* pin0, const void* pin1, void* pout, void* wsp);

    size_t getBrgIdx(size_t mIdx, size_t kIdx, size_t nIdx) {
        return mIdx * 4 + kIdx * 2 + nIdx;
    }

    std::vector<InferenceEngine::Precision> inputPrecisions;
    InferenceEngine::Precision accPrecision0;
    InferenceEngine::Precision accPrecision1;

    VectorDims dimsTranspose0In0;
    VectorDims dimsTranspose1In0;
    VectorDims dimsMulIn1;
    VectorDims dimsAddIn1;
    VectorDims dimsTranspose2In0;
    VectorDims dimsOut;

    VectorDims strTranspose0In0;
    VectorDims strTranspose1In0;
    VectorDims strMulIn1;
    VectorDims strAddIn1;
    VectorDims strTranspose2In0;
    VectorDims strOut;

    VectorDims dimsMatMul0In0;
    VectorDims dimsMatMul0In1;
    VectorDims dimsMatMul0Out;
    VectorDims dimsMatMul1In1;
    VectorDims dimsMatMul1Out;

    size_t batch0, batch1;
    size_t M, M_blk, M_tail;
    size_t K0, K0_blk, K0_tail, N0, N0_blk, N0_tail;
    size_t K1, K1_blk, K1_tail, N1, N1_blk, N1_tail;

    size_t bufferMatMul0In0Size;
    size_t bufferMatMul0In1Size;
    size_t bufferMatMul0OutSize;
    size_t bufferMatMul1In1Size;
    size_t bufferMatMul1OutSize;
    size_t bufferCompensation0Size;
    size_t bufferCompensation1Size;
    size_t wsp_size_per_thread = 4 * 1024;

    std::vector<uint8_t> bufferMatMul0In0;
    std::vector<uint8_t> bufferMatMul0In1;
    std::vector<uint8_t> bufferMatMul0Out;
    std::vector<uint8_t> bufferMatMul1In1;
    std::vector<uint8_t> bufferMatMul1Out;
    std::vector<int32_t> bufferCompensation0;
    std::vector<int32_t> bufferCompensation1;
    std::vector<size_t> wsp;

    bool isMulFirst;
    InferenceEngine::Precision fqPrc2;

    std::vector<float> mulScales;
    std::vector<float> fqScales0;
    std::vector<float> fqScales1;
    std::vector<float> fqScales2;
    std::vector<float> fqScales3;

    size_t brg0VnniFactor;
    brgemmCtx brgCtxs0[MHA_BRGEMM_KERNELS_NUM];
    std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t> brgKernels0[MHA_BRGEMM_KERNELS_NUM];
    std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_a_t> brgCopyAKernel0;
    std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t> brgCopyBKernel0;

    size_t brg1VnniFactor;
    brgemmCtx brgCtxs1[MHA_BRGEMM_KERNELS_NUM];
    std::unique_ptr<dnnl::impl::cpu::x64::brgemm_kernel_t> brgKernels1[MHA_BRGEMM_KERNELS_NUM];
    std::unique_ptr<dnnl::impl::cpu::x64::matmul::jit_brgemm_matmul_copy_b_t> brgCopyBKernel1;

    std::unique_ptr<jit_uni_mul_add_softmax_kernel> mulAddSoftmaxKernel;
    std::unique_ptr<jit_uni_convert_reorder_kernel> convertReorderKernel;
    std::unique_ptr<jit_uni_convert_transpose_kernel> convertTransposeKernel;
};

} // namespace node
} // namespace intel_cpu
} // namespace ov
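A note on the 2x2x2 indexing in getBrgIdx(): each of the M, K and N dimensions has a main block and a tail block (consistent with the M_blk/M_tail, K0_blk/K0_tail, N0_blk/N0_tail members above), so eight brgemm kernel variants cover all combinations, which is why MHA_BRGEMM_KERNELS_NUM is 8. A minimal illustrative sketch of how such a grid maps onto the kernel arrays (the loop itself is editorial, not taken from the implementation):

// Illustrative only: enumerate the 2x2x2 (M, K, N) main/tail grid that
// getBrgIdx() indexes into; idx matches mIdx * 4 + kIdx * 2 + nIdx.
for (size_t mIdx = 0; mIdx < 2; mIdx++) {          // 0: M main block, 1: M tail
    for (size_t kIdx = 0; kIdx < 2; kIdx++) {      // 0: K main block, 1: K tail
        for (size_t nIdx = 0; nIdx < 2; nIdx++) {  // 0: N main block, 1: N tail
            size_t idx = mIdx * 4 + kIdx * 2 + nIdx;  // 0..7 == MHA_BRGEMM_KERNELS_NUM - 1
        }
    }
}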
@ -88,6 +88,7 @@
#include "nodes/priorbox.h"
#include "nodes/priorbox_clustered.h"
#include "nodes/eye.h"
#include "nodes/mha.h"

namespace ov {
namespace intel_cpu {
@ -188,6 +189,7 @@ Node::NodesFactory::NodesFactory()
    INTEL_CPU_NODE(PriorBox, Type::PriorBox);
    INTEL_CPU_NODE(PriorBoxClustered, Type::PriorBoxClustered);
    INTEL_CPU_NODE(Eye, Type::Eye);
    INTEL_CPU_NODE(MHA, Type::MHA);
}

#undef INTEL_CPU_NODE
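With this registration, the "MHA" string, the Type::MHA enum value, and the node factory stay in sync; a hypothetical round-trip check (not part of this commit) would be:

// Hypothetical sanity check, not part of this patch:
assert(TypeFromName("MHA") == Type::MHA);
assert(NameFromType(Type::MHA) == "MHA");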
@ -87,6 +87,7 @@
#include <transformations/op_conversions/convert_roi_align_v3_to_v9.hpp>
#include <transformations/op_conversions/softsign_decomposition.hpp>
#include "transformations/op_conversions/eye_decomposition.hpp"
#include "ngraph_transformations/mha_fusion.hpp"

#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset2.hpp>
@ -118,6 +119,7 @@
#include "nodes/mvn.h"
#include "nodes/fake_quantize.h"
#include "nodes/normalize.h"
#include "nodes/mha.h"
#include "ngraph_transformations/convert_to_cpu_specific_opset.hpp"
#include "ngraph_transformations/move_eltwise_up_data_movement.hpp"
#include "transformations/smart_reshape/smart_reshape.hpp"
@ -249,7 +251,7 @@ Engine::~Engine() {
    executorManager()->clear("CPUCallbackExecutor");
}

static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function> nGraphFunc, const bool _enableLPT,
static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function> nGraphFunc, const bool _enableLPT, const bool _enableBF16,
                                               const bool _enableSnippets, const bool isLegacyApi) {
    ngraph::pass::Manager manager;
    manager.set_per_pass_validation(false);
@ -600,6 +602,27 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
    });
    postLPTPassManager.register_pass<ngraph::pass::ConstantFolding>();

    // Snippets may break MHA patterns, so the fusion has to be performed before them
    postLPTPassManager.register_pass<MHAFusion>();
    postLPTPassManager.get_pass_config()->set_callback<MHAFloatFusion, MHAFloatFusion2,
        MHAQuantFusion, MHAQuantFusion2>([_enableBF16](const std::shared_ptr<const ov::Node>& n) -> bool {
        std::string errorMessage;

        if (!node::MHA::isSupportedOperation(n, errorMessage))
            return true;

        // The implementation calls AMX BF16 brgemm only for tensors with K and N aligned on 2; otherwise it falls back to the vector impl.
        // The vector madd BF16 instruction on SPR has reduced performance at the HW level, which results in overall perf degradation.
        size_t bf16Factor = 2;
        if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16_amx_bf16) &&
            (n->get_input_element_type(0) == element::bf16 || (n->get_input_element_type(0) == element::f32 && _enableBF16)) &&
            (n->get_input_shape(0)[3] % bf16Factor != 0 || n->get_input_shape(1)[1] % bf16Factor != 0 || n->get_input_shape(3)[3] % bf16Factor != 0)) {
            return true;
        }

        return false;
    });
    postLPTPassManager.run_passes(nGraphFunc);
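In the callback above, returning true vetoes the fusion for that candidate. The shape condition is dense, so here is the same predicate factored into a standalone helper purely as an illustration (the helper name and the K/N reading of the dimensions are editorial, not part of the patch):

// Illustrative restatement of the shape check above. Given the transposes the
// MHA pattern applies, in0[3] and in1[1] roughly act as K and N of the first
// matmul, and in3[3] as N of the second; AMX BF16 brgemm needs them even.
static bool hasUnalignedBf16Dims(const std::shared_ptr<const ov::Node>& n) {
    const size_t bf16Factor = 2;
    return n->get_input_shape(0)[3] % bf16Factor != 0 ||
           n->get_input_shape(1)[1] % bf16Factor != 0 ||
           n->get_input_shape(3)[3] % bf16Factor != 0;
}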
    if (!useLpt && _enableSnippets && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2)) {
@ -631,9 +654,9 @@ static void TransformationUpToCPUSpecificOpSet(std::shared_ptr<ngraph::Function>
    }
}

static void Transformation(CNNNetwork& clonedNetwork, const bool _enableLPT, const bool _enableSnippets, const bool isLegacyApi) {
static void Transformation(CNNNetwork& clonedNetwork, const bool _enableLPT, const bool _enableBF16, const bool _enableSnippets, const bool isLegacyApi) {
    auto nGraphFunc = clonedNetwork.getFunction();
    TransformationUpToCPUSpecificOpSet(nGraphFunc, _enableLPT, _enableSnippets, isLegacyApi);
    TransformationUpToCPUSpecificOpSet(nGraphFunc, _enableLPT, _enableBF16, _enableSnippets, isLegacyApi);
    ConvertToCPUSpecificOpset(nGraphFunc);
}

@ -774,7 +797,7 @@ Engine::LoadExeNetworkImpl(const InferenceEngine::CNNNetwork &network, const std
        || engConfig.enableDynamicBatch;
    const bool enableSnippets = !(enableModelCache || enableDynamicBatch || enableBF16);
    auto nGraphFunc = clonedNetwork.getFunction();
    TransformationUpToCPUSpecificOpSet(nGraphFunc, enableLPT, enableSnippets, isLegacyAPI());
    TransformationUpToCPUSpecificOpSet(nGraphFunc, enableLPT, enableBF16, enableSnippets, isLegacyAPI());

    // need to check that all outputs have static shapes
    // checking that all inputs have static shapes is performed in the common part
@ -1022,7 +1045,7 @@ QueryNetworkResult Engine::QueryNetwork(const CNNNetwork& network, const std::ma
        || Config::LPTransformsMode::On == engConfig.lpTransformsMode /* or already enabled */;
    const bool enableSnippets = !(conf.cache_dir.empty() || conf.enableDynamicBatch || (conf.enforceBF16
        && dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core)));
    Transformation(clonedNetwork, enableLPT, enableSnippets, isLegacyAPI());
    Transformation(clonedNetwork, enableLPT, conf.enforceBF16, enableSnippets, isLegacyAPI());
    auto ops = clonnedFunction->get_ordered_ops();

    //Mark removed nodes as supported
@ -0,0 +1,549 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <tuple>
#include <string>
#include <vector>
#include <memory>
#include <debug.h>
#include <shared_test_classes/base/ov_subgraph.hpp>
#include <ngraph_functions/builders.hpp>
#include "common_test_utils/common_utils.hpp"
#include <common_test_utils/ov_tensor_utils.hpp>
#include "functional_test_utils/skip_tests_config.hpp"
#include "test_utils/cpu_test_utils.hpp"

using namespace CPUTestUtils;
using namespace ov::test;
using namespace ngraph::helpers;

namespace CPUSubgraphTestsDefinitions {

typedef std::tuple<
        std::vector<InputShape>,   // Input shapes
        std::vector<ElementType>,  // Input precisions
        std::vector<ElementType>,  // MatMul input #0 precisions
        size_t,                    // Pattern type #
        std::string                // Device name
> MHATuple;
static std::shared_ptr<ov::Model> initMHASubgraph0(std::vector<ov::PartialShape>& inputDynamicShapes, std::vector<ElementType>& inputPrecisions) {
    ngraph::ParameterVector ngraphParam;

    auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[0], inputDynamicShapes[0]);
    ngraphParam.push_back(transpose0Param);

    auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[1], inputDynamicShapes[1]);
    ngraphParam.push_back(transpose1Param);

    auto addParam = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[2], inputDynamicShapes[2]);
    ngraphParam.push_back(addParam);

    auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[3], inputDynamicShapes[3]);
    ngraphParam.push_back(transpose2Param);

    std::vector<ov::Shape> constantShapes;
    constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
    constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
    constantShapes.push_back(ov::Shape({1, inputDynamicShapes[0].get_shape()[2], 1, 1}));
    constantShapes.push_back(ov::Shape({2}));
    constantShapes.push_back(ov::Shape({4}));
    constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
    constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));

    std::vector<int64_t> transpose0ConstData = {0, 2, 1, 3};
    auto transpose0Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[0], transpose0ConstData);

    std::vector<int64_t> transpose1ConstData = {0, 2, 3, 1};
    auto transpose1Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[1], transpose1ConstData);

    std::vector<float> mulConstData(ngraph::shape_size(constantShapes[2]));
    auto mulConst = ngraph::builder::makeConstant(inputPrecisions[0], constantShapes[2], mulConstData, true);

    std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(inputDynamicShapes[0].get_shape()[0] *
                                                                   inputDynamicShapes[0].get_shape()[1] * inputDynamicShapes[0].get_shape()[2]),
                                              -1};
    auto reshape0Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[3], reshape0ConstData);

    std::vector<int64_t> reshape1ConstData = {static_cast<int64_t>(inputDynamicShapes[0].get_shape()[0]),
                                              static_cast<int64_t>(inputDynamicShapes[0].get_shape()[2]),
                                              static_cast<int64_t>(inputDynamicShapes[0].get_shape()[1]),
                                              static_cast<int64_t>(inputDynamicShapes[0].get_shape()[1])};
    auto reshape1Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[4], reshape1ConstData);

    std::vector<int64_t> transpose2ConstData = {0, 2, 1, 3};
    auto transpose2Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[5], transpose2ConstData);

    std::vector<int64_t> transpose3ConstData = {0, 2, 1, 3};
    auto transpose3Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[6], transpose3ConstData);

    float transA = false;
    float transB = false;
    const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(transpose0Param, transpose0Const);
    const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(transpose1Param, transpose1Const);
    const auto mul = std::make_shared<ngraph::opset3::Multiply>(transpose1, mulConst);
    const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, mul, transA, transB);
    const auto add = std::make_shared<ngraph::opset3::Add>(matMul0, addParam);
    const auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(add, reshape0Const, true);
    const auto softMax = std::make_shared<ngraph::opset1::Softmax>(reshape0, 1);
    const auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softMax, reshape1Const, true);
    const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(transpose2Param, transpose2Const);
    const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(reshape1, transpose2, transA, transB);
    const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(matMul1, transpose3Const);

    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
    return std::make_shared<ngraph::Function>(results, ngraphParam, "mha");
}
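To make the builder easier to follow, here is an editorial shape trace of pattern 0 for one of the tested shapes, {1, 384, 16, 64} (batch, sequence, heads, head size), derived from the constants above:

// Shape trace for pattern 0 with inputs {1, 384, 16, 64} (illustrative):
//   transpose0 {0,2,1,3}: {1, 384, 16, 64} -> {1, 16, 384, 64}
//   transpose1 {0,2,3,1}: {1, 384, 16, 64} -> {1, 16, 64, 384}, then scaled by mulConst
//   matMul0: {1, 16, 384, 64} x {1, 16, 64, 384} -> {1, 16, 384, 384}
//   add (broadcast {1, 1, 1, 384}); reshape0 -> {6144, 384}; softmax along axis 1
//   reshape1 back to {1, 16, 384, 384}
//   matMul1: {1, 16, 384, 384} x {1, 16, 384, 64} -> {1, 16, 384, 64}
//   transpose3 {0,2,1,3}: -> {1, 384, 16, 64}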
static std::shared_ptr<ov::Model> initMHASubgraph1(std::vector<ov::PartialShape>& inputDynamicShapes, std::vector<ElementType>& inputPrecisions) {
    ngraph::ParameterVector ngraphParam;

    auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[0], inputDynamicShapes[0]);
    ngraphParam.push_back(transpose0Param);

    auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[1], inputDynamicShapes[1]);
    ngraphParam.push_back(transpose1Param);

    auto addParam = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[2], inputDynamicShapes[2]);
    ngraphParam.push_back(addParam);

    auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[3], inputDynamicShapes[3]);
    ngraphParam.push_back(transpose2Param);

    std::vector<ov::Shape> constantShapes;
    constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
    constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
    constantShapes.push_back(ov::Shape({1, inputDynamicShapes[0].get_shape()[2], 1, 1}));
    constantShapes.push_back(ov::Shape({2}));
    constantShapes.push_back(ov::Shape({4}));
    constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
    constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));

    std::vector<int64_t> transpose0ConstData = {0, 2, 1, 3};
    auto transpose0Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[0], transpose0ConstData);

    std::vector<int64_t> transpose1ConstData = {0, 2, 3, 1};
    auto transpose1Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[1], transpose1ConstData);

    std::vector<int64_t> transpose2ConstData = {0, 2, 1, 3};
    auto transpose2Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[0], transpose2ConstData);

    std::vector<int64_t> transpose3ConstData = {0, 2, 1, 3};
    auto transpose3Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[1], transpose3ConstData);

    float transA = false;
    float transB = false;
    const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(transpose0Param, transpose0Const);
    const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(transpose1Param, transpose1Const);
    const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1, transA, transB);
    const auto add = std::make_shared<ngraph::opset3::Add>(matMul0, addParam);
    const auto softMax = std::make_shared<ngraph::opset1::Softmax>(add, 3);
    const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(transpose2Param, transpose2Const);
    const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softMax, transpose2, transA, transB);
    const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(matMul1, transpose3Const);

    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
    return std::make_shared<ngraph::Function>(results, ngraphParam, "mha");
}
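Pattern 1 differs from pattern 0 in two ways: there is no extra Multiply on the key branch, and Softmax is applied directly to the rank-4 tensor along axis 3 instead of going through the Reshape / Softmax(axis 1) / Reshape sandwich, so the attention tensor stays 4D throughout. An editorial trace on the same example shape:

// Pattern 1 with inputs {1, 384, 16, 64} (illustrative):
//   matMul0 -> {1, 16, 384, 384}; add (broadcast); softmax along axis 3
//   matMul1 -> {1, 16, 384, 64}; transpose3 {0,2,1,3} -> {1, 384, 16, 64}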
class MHATest : public testing::WithParamInterface<MHATuple>,
                virtual public SubgraphBaseTest, public CPUTestsBase {
public:
    static std::string getTestCaseName(const testing::TestParamInfo<MHATuple> &obj) {
        std::vector<InputShape> inputShapes;
        std::vector<ElementType> inputPrecisions;
        std::vector<ElementType> matMulIn0Precisions;
        size_t patternType;
        std::string targetName;
        std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetName) = obj.param;
        std::ostringstream results;

        results << "IS=(";
        for (const auto& shape : inputShapes) {
            results << CommonTestUtils::partialShape2str({shape.first}) << "_";
        }
        results << ")_TS=(";
        for (const auto& shape : inputShapes) {
            for (const auto& item : shape.second) {
                results << CommonTestUtils::vec2str(item) << "_";
            }
        }
        for (int i = 0; i < inputPrecisions.size(); i++) {
            results << "InPRC" << std::to_string(i) << "=" << inputPrecisions[i] << "_";
        }
        results << "patternType=" << patternType;
        results << "targetDevice=" << targetName;

        return results.str();
    }

    void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override {
        inputs.clear();
        const auto& funcInputs = function->inputs();
        for (int i = 0; i < funcInputs.size(); ++i) {
            const auto& funcInput = funcInputs[i];
            ov::Tensor tensor;
            tensor = ov::test::utils::create_and_fill_tensor_normal_distribution(funcInput.get_element_type(), targetInputStaticShapes[i], 1.0f, 0.5f);
            inputs.insert({funcInput.get_node_shared_ptr(), tensor});
        }
    }

protected:
    void SetUp() override {
        std::vector<InputShape> inputShapes;
        std::vector<ElementType> inputPrecisions;
        std::vector<ElementType> matMulIn0Precisions;
        size_t patternType;
        std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam();

        init_input_shapes(inputShapes);

        if (patternType == 0) {
            function = initMHASubgraph0(inputDynamicShapes, inputPrecisions);
        } else if (patternType == 1) {
            function = initMHASubgraph1(inputDynamicShapes, inputPrecisions);
        } else {
            FAIL() << "Unsupported MHA pattern type";
        }

        // TODO: try better input data initialization to avoid threshold adjustment
        // TODO: support different precisions on inputs
        if (inputPrecisions[0] == ElementType::bf16) {
            abs_threshold = 0.1f;
            rel_threshold = 10.f;

            configuration.insert({{ InferenceEngine::PluginConfigParams::KEY_ENFORCE_BF16, InferenceEngine::PluginConfigParams::YES }});
        }
    }
};
TEST_P(MHATest, CompareWithRefs) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()

    std::vector<InputShape> inputShapes;
    std::vector<ElementType> inputPrecisions;
    std::vector<ElementType> matMulIn0Precisions;
    size_t patternType;
    std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam();

    if (inputPrecisions[0] == ElementType::bf16 && !InferenceEngine::with_cpu_x86_bfloat16())
        GTEST_SKIP();

    if (!InferenceEngine::with_cpu_x86_avx512_core())
        GTEST_SKIP();

    run();
    CheckNumberOfNodesWithType(compiledModel, "MHA", 1);
}
namespace {

std::vector<std::vector<ngraph::Shape>> inputShapes = {
    {{2, 8, 16, 64}, {2, 8, 16, 64}, {2, 1, 1, 8}, {2, 8, 16, 64}},
    {{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}},
    {{2, 64, 16, 80}, {2, 64, 16, 80}, {2, 1, 1, 64}, {2, 64, 16, 80}},
    {{3, 96, 16, 64}, {3, 96, 16, 64}, {3, 1, 1, 96}, {3, 96, 16, 64}},
    {{2, 192, 16, 160}, {2, 192, 16, 160}, {2, 1, 1, 192}, {2, 192, 16, 160}},
    {{2, 4, 16, 8}, {2, 4, 16, 8}, {2, 1, 1, 4}, {2, 4, 16, 8}},
    {{1, 204, 13, 212}, {1, 204, 13, 212}, {1, 1, 1, 204}, {1, 204, 13, 212}},
};

std::vector<std::vector<ElementType>> inputPrecisions = {
    { ElementType::f32, ElementType::f32, ElementType::f32, ElementType::f32 },
    { ElementType::bf16, ElementType::bf16, ElementType::bf16, ElementType::bf16 },
};

std::vector<std::vector<ElementType>> matMulIn0Precisions = {
    {},
};

std::vector<size_t> patternTypes = {
    0, 1
};

INSTANTIATE_TEST_SUITE_P(smoke_MHA, MHATest,
                         ::testing::Combine(
                                 ::testing::ValuesIn(static_shapes_to_test_representation(inputShapes)),
                                 ::testing::ValuesIn(inputPrecisions),
                                 ::testing::ValuesIn(matMulIn0Precisions),
                                 ::testing::ValuesIn(patternTypes),
                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         MHATest::getTestCaseName);

} // namespace
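For orientation (an editorial note, not part of the patch): the Combine above expands to 7 shape sets x 2 input-precision sets x 1 matMulIn0 set x 2 pattern types = 28 CPU test cases, each expected to compile to exactly one MHA node.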
static std::shared_ptr<ov::Model> initMHAQuantSubgraph0(std::vector<ov::PartialShape>& inputDynamicShapes, std::vector<ElementType>& inputPrecisions,
                                                        std::vector<ElementType>& matMulIn0Precisions) {
    ngraph::ParameterVector ngraphParam;

    auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[0], inputDynamicShapes[0]);
    ngraphParam.push_back(transpose0Param);

    auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[1], inputDynamicShapes[1]);
    ngraphParam.push_back(transpose1Param);

    auto addParam = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[2], inputDynamicShapes[2]);
    ngraphParam.push_back(addParam);

    auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[3], inputDynamicShapes[3]);
    ngraphParam.push_back(transpose2Param);

    std::vector<ov::Shape> constantShapes;
    constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
    constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
    constantShapes.push_back(ov::Shape({2}));
    constantShapes.push_back(ov::Shape({4}));
    constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
    constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));

    std::vector<int64_t> transpose0ConstData = {0, 2, 1, 3};
    auto transpose0Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[0], transpose0ConstData);

    std::vector<int64_t> transpose1ConstData = {0, 2, 3, 1};
    auto transpose1Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[1], transpose1ConstData);

    std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(inputDynamicShapes[0].get_shape()[0] *
                                                                   inputDynamicShapes[0].get_shape()[1] * inputDynamicShapes[0].get_shape()[2]),
                                              -1};
    auto reshape0Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[2], reshape0ConstData);

    std::vector<int64_t> reshape1ConstData = {static_cast<int64_t>(inputDynamicShapes[0].get_shape()[0]),
                                              static_cast<int64_t>(inputDynamicShapes[0].get_shape()[2]),
                                              static_cast<int64_t>(inputDynamicShapes[0].get_shape()[1]),
                                              static_cast<int64_t>(inputDynamicShapes[0].get_shape()[1])};
    auto reshape1Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[3], reshape1ConstData);

    std::vector<int64_t> transpose2ConstData = {0, 2, 1, 3};
    auto transpose2Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[4], transpose2ConstData);

    std::vector<int64_t> transpose3ConstData = {0, 2, 1, 3};
    auto transpose3Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[5], transpose3ConstData);

    float transA = false;
    float transB = false;

    std::shared_ptr<ov::Node> fakeQuantize0;
    if (matMulIn0Precisions[0] == ElementType::u8)
        fakeQuantize0 = ngraph::builder::makeFakeQuantize(transpose0Param, inputPrecisions[0], 256, {}, {0.0f}, {2.55f}, {0.0f}, {2.55f});
    else
        fakeQuantize0 = ngraph::builder::makeFakeQuantize(transpose0Param, inputPrecisions[0], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});

    const auto fakeQuantize1 = ngraph::builder::makeFakeQuantize(transpose1Param, inputPrecisions[1], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
    const auto fakeQuantize2 = ngraph::builder::makeFakeQuantize(transpose2Param, inputPrecisions[3], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});

    std::shared_ptr<ov::Node> fakeQuantize4;

    const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(fakeQuantize0, transpose0Const);
    const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(fakeQuantize1, transpose1Const);
    const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, transpose1, transA, transB);
    const auto fakeQuantize3 = ngraph::builder::makeFakeQuantize(matMul0, inputPrecisions[0], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
    const auto add = std::make_shared<ngraph::opset3::Add>(fakeQuantize3, addParam);
    const auto reshape0 = std::make_shared<ngraph::opset1::Reshape>(add, reshape0Const, true);
    const auto softMax = std::make_shared<ngraph::opset1::Softmax>(reshape0, 1);
    const auto reshape1 = std::make_shared<ngraph::opset1::Reshape>(softMax, reshape1Const, true);
    if (matMulIn0Precisions[1] == ElementType::u8)
        fakeQuantize4 = ngraph::builder::makeFakeQuantize(reshape1, inputPrecisions[0], 256, {}, {0.0f}, {0.255f}, {0.0f}, {0.255f});
    else
        fakeQuantize4 = ngraph::builder::makeFakeQuantize(reshape1, inputPrecisions[0], 256, {}, {-0.128f}, {0.127f}, {-0.128f}, {0.127f});
    const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(fakeQuantize2, transpose2Const);
    const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(fakeQuantize4, transpose2, transA, transB);
    const auto fakeQuantize5 = ngraph::builder::makeFakeQuantize(matMul1, inputPrecisions[0], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
    const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(fakeQuantize5, transpose3Const);

    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
    return std::make_shared<ngraph::Function>(results, ngraphParam, "mha");
}
static std::shared_ptr<ov::Model> initMHAQuantSubgraph1(std::vector<ov::PartialShape>& inputDynamicShapes, std::vector<ElementType>& inputPrecisions,
                                                        std::vector<ElementType>& matMulIn0Precisions) {
    ngraph::ParameterVector ngraphParam;

    auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[0], inputDynamicShapes[0]);
    ngraphParam.push_back(transpose0Param);

    auto transpose1Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[1], inputDynamicShapes[1]);
    ngraphParam.push_back(transpose1Param);

    auto addParam = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[2], inputDynamicShapes[2]);
    ngraphParam.push_back(addParam);

    auto transpose2Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[3], inputDynamicShapes[3]);
    ngraphParam.push_back(transpose2Param);

    std::vector<ov::Shape> constantShapes;
    constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
    constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
    constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
    constantShapes.push_back(ov::Shape({inputDynamicShapes[0].get_shape().size()}));
    constantShapes.push_back(ov::Shape({1}));

    std::vector<int64_t> transpose0ConstData = {0, 2, 1, 3};
    auto transpose0Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[0], transpose0ConstData);

    std::vector<int64_t> transpose1ConstData = {0, 2, 3, 1};
    auto transpose1Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[1], transpose1ConstData);

    std::vector<int64_t> transpose2ConstData = {0, 2, 1, 3};
    auto transpose2Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[2], transpose2ConstData);

    std::vector<int64_t> transpose3ConstData = {0, 2, 1, 3};
    auto transpose3Const = ngraph::builder::makeConstant(ElementType::i64, constantShapes[3], transpose3ConstData);

    std::vector<float> mulConstData(ngraph::shape_size(constantShapes[4]));
    auto mulConst = ngraph::builder::makeConstant(inputPrecisions[0], constantShapes[4], mulConstData, true);

    float transA = false;
    float transB = false;

    std::shared_ptr<ov::Node> fakeQuantize0;
    if (matMulIn0Precisions[0] == ElementType::u8)
        fakeQuantize0 = ngraph::builder::makeFakeQuantize(transpose0Param, inputPrecisions[0], 256, {}, {0.0f}, {2.55f}, {0.0f}, {2.55f});
    else
        fakeQuantize0 = ngraph::builder::makeFakeQuantize(transpose0Param, inputPrecisions[0], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});

    const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(fakeQuantize0, transpose0Const);
    const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(transpose1Param, transpose1Const);
    const auto fakeQuantize1 = ngraph::builder::makeFakeQuantize(transpose1, inputPrecisions[1], 256, {}, {-1.28f}, {1.27f}, {-1.28f}, {1.27f});
    const auto matMul0 = std::make_shared<ngraph::opset3::MatMul>(transpose0, fakeQuantize1, transA, transB);
    const auto mul = std::make_shared<ngraph::opset3::Multiply>(addParam, mulConst);
    const auto add = std::make_shared<ngraph::opset3::Add>(matMul0, mul);
    const auto softMax = std::make_shared<ngraph::opset1::Softmax>(add, 3);
    const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(transpose2Param, transpose2Const);
    const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softMax, transpose2, transA, transB);
    const auto fakeQuantize2 = ngraph::builder::makeFakeQuantize(matMul1, inputPrecisions[0], 256, {}, {0.0f}, {2.55f}, {0.0f}, {2.55f});
    const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(fakeQuantize2, transpose3Const);

    ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
    return std::make_shared<ngraph::Function>(results, ngraphParam, "mha");
}
class MHAQuantTest : public testing::WithParamInterface<MHATuple>,
                     virtual public SubgraphBaseTest, public CPUTestsBase {
public:
    static std::string getTestCaseName(const testing::TestParamInfo<MHATuple> &obj) {
        std::vector<InputShape> inputShapes;
        std::vector<ElementType> inputPrecisions;
        std::vector<ElementType> matMulIn0Precisions;
        size_t patternType;
        std::string targetName;
        std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetName) = obj.param;
        std::ostringstream results;

        results << "IS=(";
        for (const auto& shape : inputShapes) {
            results << CommonTestUtils::partialShape2str({shape.first}) << "_";
        }
        results << ")_TS=(";
        for (const auto& shape : inputShapes) {
            for (const auto& item : shape.second) {
                results << CommonTestUtils::vec2str(item) << "_";
            }
        }
        for (int i = 0; i < inputPrecisions.size(); i++) {
            results << "InPRC" << std::to_string(i) << "=" << inputPrecisions[i] << "_";
        }
        for (int i = 0; i < matMulIn0Precisions.size(); i++) {
            results << "MatMulIn0PRC" << std::to_string(i) << "=" << matMulIn0Precisions[i] << "_";
        }
        results << "patternType=" << patternType;
        results << "targetDevice=" << targetName;

        return results.str();
    }

    void generate_inputs(const std::vector<ngraph::Shape>& targetInputStaticShapes) override {
        inputs.clear();
        const auto& funcInputs = function->inputs();
        for (int i = 0; i < funcInputs.size(); ++i) {
            const auto& funcInput = funcInputs[i];
            ov::Tensor tensor;
            if (funcInput.get_element_type().is_real())
                tensor = ov::test::utils::create_and_fill_tensor_normal_distribution(funcInput.get_element_type(), targetInputStaticShapes[i], 0.0f, 1.5f);
            else
                tensor = ov::test::utils::create_and_fill_tensor(funcInput.get_element_type(), targetInputStaticShapes[i], 255, 0, 1);

            inputs.insert({funcInput.get_node_shared_ptr(), tensor});
        }
    }

protected:
    void SetUp() override {
        abs_threshold = 0.1f;

        std::vector<InputShape> inputShapes;
        std::vector<ElementType> inputPrecisions;
        std::vector<ElementType> matMulIn0Precisions;
        size_t patternType;
        std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam();

        init_input_shapes(inputShapes);

        if (patternType == 0) {
            function = initMHAQuantSubgraph0(inputDynamicShapes, inputPrecisions, matMulIn0Precisions);
        } else if (patternType == 1) {
            function = initMHAQuantSubgraph1(inputDynamicShapes, inputPrecisions, matMulIn0Precisions);
        } else {
            FAIL() << "Unsupported MHA pattern type";
        }
    }
};
TEST_P(MHAQuantTest, CompareWithRefs) {
    SKIP_IF_CURRENT_TEST_IS_DISABLED()

    std::vector<InputShape> inputShapes;
    std::vector<ElementType> inputPrecisions;
    std::vector<ElementType> matMulIn0Precisions;
    size_t patternType;
    std::tie(inputShapes, inputPrecisions, matMulIn0Precisions, patternType, targetDevice) = this->GetParam();

    if (inputPrecisions[0] == ElementType::bf16 && !InferenceEngine::with_cpu_x86_bfloat16())
        GTEST_SKIP();

    if (!InferenceEngine::with_cpu_x86_avx512_core_vnni())
        GTEST_SKIP();

    run();
    CheckNumberOfNodesWithType(compiledModel, "MHA", 1);
}
namespace {

std::vector<std::vector<ngraph::Shape>> inputShapesQuant = {
    {{2, 7, 16, 9}, {2, 7, 16, 9}, {2, 1, 1, 7}, {2, 7, 16, 9}},
    {{2, 8, 16, 64}, {2, 8, 16, 64}, {2, 1, 1, 8}, {2, 8, 16, 64}},
    {{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}},
    {{2, 64, 16, 80}, {2, 64, 16, 80}, {2, 1, 1, 64}, {2, 64, 16, 80}},
    {{3, 96, 16, 64}, {3, 96, 16, 64}, {3, 1, 1, 96}, {3, 96, 16, 64}},
    {{2, 192, 16, 160}, {2, 192, 16, 160}, {2, 1, 1, 192}, {2, 192, 16, 160}},
    {{2, 4, 16, 8}, {2, 4, 16, 8}, {2, 1, 1, 4}, {2, 4, 16, 8}},
    {{1, 204, 13, 212}, {1, 204, 13, 212}, {1, 1, 1, 204}, {1, 204, 13, 212}},
    {{1, 207, 13, 211}, {1, 207, 13, 211}, {1, 1, 1, 207}, {1, 207, 13, 211}},
};

std::vector<std::vector<ElementType>> inputPrecisionsQuant = {
    { ElementType::f32, ElementType::f32, ElementType::f32, ElementType::f32 },
};

std::vector<std::vector<ElementType>> matMulIn0PrecisionsQuant = {
    { ElementType::i8, ElementType::i8 },
    { ElementType::i8, ElementType::u8 },
};

std::vector<size_t> patternTypesQuant = {
    0, 1
};

INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant, MHAQuantTest,
                         ::testing::Combine(
                                 ::testing::ValuesIn(static_shapes_to_test_representation(inputShapesQuant)),
                                 ::testing::ValuesIn(inputPrecisionsQuant),
                                 ::testing::ValuesIn(matMulIn0PrecisionsQuant),
                                 ::testing::ValuesIn(patternTypesQuant),
                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         MHAQuantTest::getTestCaseName);

} // namespace

} // namespace CPUSubgraphTestsDefinitions