[CPU] Quantized MHA extension for SmoothQuant (#17906)

This commit is contained in:
Edward Shogulin 2023-06-07 15:31:06 +01:00 committed by GitHub
parent 2547301fa7
commit 655c21adf1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 40 additions and 24 deletions

View File

@ -8,6 +8,7 @@
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
#include "transformations/cpu_opset/x64/op/mha.hpp"
#include "simplify_fakequantize.hpp"
@ -470,7 +471,8 @@ ov::intel_cpu::MHAQuantFusion2::MHAQuantFusion2() {
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(fakeQuantize1, in10);
auto in11 = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{ matmul1, fakeQuantize1 });
auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(in11, in10);
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
auto& pattern_to_output = m.get_pattern_value_map();
@ -532,10 +534,13 @@ ov::intel_cpu::MHAQuantFusion2::MHAQuantFusion2() {
return false;
std::vector<float> fq1_scale;
if (auto fq_node = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize1).get_node_shared_ptr())) {
fq1_scale = simplifyToScale(fq_node);
if (!fq1_scale.size())
return false;
const bool fakeQuantize1Exists = pattern_to_output.find(fakeQuantize1) != pattern_to_output.end();
if (fakeQuantize1Exists) {
if (auto fq_node = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize1).get_node_shared_ptr())) {
fq1_scale = simplifyToScale(fq_node);
if (!fq1_scale.size())
return false;
}
}
auto matmul1_node = ngraph::as_type_ptr<ngraph::opset3::MatMul>(pattern_to_output.at(matmul1).get_node_shared_ptr());
@ -551,19 +556,23 @@ ov::intel_cpu::MHAQuantFusion2::MHAQuantFusion2() {
fq0_node->get_output_element_type(0), ngraph::element::undefined, ngraph::element::undefined,
transpose3_node->get_output_element_type(0));
mha->set_friendly_name(m.get_match_root()->get_friendly_name());
ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(),
pattern_to_output.at(transpose1).get_node_shared_ptr(),
pattern_to_output.at(fakeQuantize0).get_node_shared_ptr(),
pattern_to_output.at(matmul0).get_node_shared_ptr(),
pattern_to_output.at(mul).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(softmax).get_node_shared_ptr(),
pattern_to_output.at(transpose2).get_node_shared_ptr(),
pattern_to_output.at(matmul1).get_node_shared_ptr(),
pattern_to_output.at(fakeQuantize1).get_node_shared_ptr(),
pattern_to_output.at(transpose3).get_node_shared_ptr(),
},
mha);
std::vector<std::shared_ptr<Node>> merged = {
pattern_to_output.at(transpose0).get_node_shared_ptr(),
pattern_to_output.at(transpose1).get_node_shared_ptr(),
pattern_to_output.at(fakeQuantize0).get_node_shared_ptr(),
pattern_to_output.at(matmul0).get_node_shared_ptr(),
pattern_to_output.at(mul).get_node_shared_ptr(),
pattern_to_output.at(add).get_node_shared_ptr(),
pattern_to_output.at(softmax).get_node_shared_ptr(),
pattern_to_output.at(transpose2).get_node_shared_ptr(),
pattern_to_output.at(matmul1).get_node_shared_ptr(),
pattern_to_output.at(transpose3).get_node_shared_ptr(),
};
if (fakeQuantize1Exists) {
merged.push_back(pattern_to_output.at(fakeQuantize1).get_node_shared_ptr());
}
ngraph::copy_runtime_info(merged, mha);
if (transformation_callback(mha)) {
return false;

View File

@ -372,8 +372,10 @@ static std::shared_ptr<ov::Model> initMHAQuantSubgraph0(std::vector<ov::PartialS
return std::make_shared<ngraph::Function>(results, ngraphParam, "mha");
}
static std::shared_ptr<ov::Model> initMHAQuantSubgraph1(std::vector<ov::PartialShape>& inputDynamicShapes, std::vector<ElementType>& inputPrecisions,
std::vector<ElementType>& matMulIn0Precisions) {
static std::shared_ptr<ov::Model> initMHAQuantSubgraph1(const std::vector<ov::PartialShape>& inputDynamicShapes,
const std::vector<ElementType>& inputPrecisions,
const std::vector<ElementType>& matMulIn0Precisions,
const bool fakeQuantize3Exists) {
ngraph::ParameterVector ngraphParam;
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[0], inputDynamicShapes[0]);
@ -428,8 +430,11 @@ static std::shared_ptr<ov::Model> initMHAQuantSubgraph1(std::vector<ov::PartialS
const auto softMax = std::make_shared<ngraph::opset1::Softmax>(add, 3);
const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(transpose2Param, transpose2Const);
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softMax, transpose2, transA, transB);
const auto fakeQuantize2 = ngraph::builder::makeFakeQuantize(matMul1, inputPrecisions[0], 256, {}, {0.0f}, {2.55f}, {0.0f}, {2.55f});
const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(fakeQuantize2, transpose3Const);
const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(
fakeQuantize3Exists ?
ngraph::builder::makeFakeQuantize(matMul1, inputPrecisions[0], 256, {}, { 0.0f }, { 2.55f }, { 0.0f }, { 2.55f }) :
matMul1,
transpose3Const);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
return std::make_shared<ngraph::Function>(results, ngraphParam, "mha");
@ -503,7 +508,9 @@ protected:
if (patternType == 0) {
function = initMHAQuantSubgraph0(inputDynamicShapes, inputPrecisions, matMulIn0Precisions);
} else if (patternType == 1) {
function = initMHAQuantSubgraph1(inputDynamicShapes, inputPrecisions, matMulIn0Precisions);
function = initMHAQuantSubgraph1(inputDynamicShapes, inputPrecisions, matMulIn0Precisions, true);
} else if (patternType == 2) {
function = initMHAQuantSubgraph1(inputDynamicShapes, inputPrecisions, matMulIn0Precisions, false);
} else {
FAIL() << "Unsupported MHA pattern type";
}
@ -559,7 +566,7 @@ std::vector<std::vector<ElementType>> matMulIn0PrecisionsQuant = {
};
std::vector<size_t> patternTypesQuant = {
0, 1
0, 1, 2
};
INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant, MHAQuantTest,