[CPU] Quantized MHA extension for SmoothQuant (#17906)
This commit is contained in:
parent
2547301fa7
commit
655c21adf1
@ -8,6 +8,7 @@
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/pattern/op/or.hpp>
|
||||
#include "transformations/cpu_opset/x64/op/mha.hpp"
|
||||
#include "simplify_fakequantize.hpp"
|
||||
|
||||
@ -470,7 +471,8 @@ ov::intel_cpu::MHAQuantFusion2::MHAQuantFusion2() {
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>(),
|
||||
ngraph::pattern::wrap_type<ngraph::opset4::Constant>()});
|
||||
auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(fakeQuantize1, in10);
|
||||
auto in11 = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{ matmul1, fakeQuantize1 });
|
||||
auto transpose3 = std::make_shared<ngraph::opset3::Transpose>(in11, in10);
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
auto& pattern_to_output = m.get_pattern_value_map();
|
||||
@ -532,10 +534,13 @@ ov::intel_cpu::MHAQuantFusion2::MHAQuantFusion2() {
|
||||
return false;
|
||||
|
||||
std::vector<float> fq1_scale;
|
||||
if (auto fq_node = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize1).get_node_shared_ptr())) {
|
||||
fq1_scale = simplifyToScale(fq_node);
|
||||
if (!fq1_scale.size())
|
||||
return false;
|
||||
const bool fakeQuantize1Exists = pattern_to_output.find(fakeQuantize1) != pattern_to_output.end();
|
||||
if (fakeQuantize1Exists) {
|
||||
if (auto fq_node = ngraph::as_type_ptr<ngraph::opset1::FakeQuantize>(pattern_to_output.at(fakeQuantize1).get_node_shared_ptr())) {
|
||||
fq1_scale = simplifyToScale(fq_node);
|
||||
if (!fq1_scale.size())
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
auto matmul1_node = ngraph::as_type_ptr<ngraph::opset3::MatMul>(pattern_to_output.at(matmul1).get_node_shared_ptr());
|
||||
@ -551,19 +556,23 @@ ov::intel_cpu::MHAQuantFusion2::MHAQuantFusion2() {
|
||||
fq0_node->get_output_element_type(0), ngraph::element::undefined, ngraph::element::undefined,
|
||||
transpose3_node->get_output_element_type(0));
|
||||
mha->set_friendly_name(m.get_match_root()->get_friendly_name());
|
||||
ngraph::copy_runtime_info({pattern_to_output.at(transpose0).get_node_shared_ptr(),
|
||||
pattern_to_output.at(transpose1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(fakeQuantize0).get_node_shared_ptr(),
|
||||
pattern_to_output.at(matmul0).get_node_shared_ptr(),
|
||||
pattern_to_output.at(mul).get_node_shared_ptr(),
|
||||
pattern_to_output.at(add).get_node_shared_ptr(),
|
||||
pattern_to_output.at(softmax).get_node_shared_ptr(),
|
||||
pattern_to_output.at(transpose2).get_node_shared_ptr(),
|
||||
pattern_to_output.at(matmul1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(fakeQuantize1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(transpose3).get_node_shared_ptr(),
|
||||
},
|
||||
mha);
|
||||
std::vector<std::shared_ptr<Node>> merged = {
|
||||
pattern_to_output.at(transpose0).get_node_shared_ptr(),
|
||||
pattern_to_output.at(transpose1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(fakeQuantize0).get_node_shared_ptr(),
|
||||
pattern_to_output.at(matmul0).get_node_shared_ptr(),
|
||||
pattern_to_output.at(mul).get_node_shared_ptr(),
|
||||
pattern_to_output.at(add).get_node_shared_ptr(),
|
||||
pattern_to_output.at(softmax).get_node_shared_ptr(),
|
||||
pattern_to_output.at(transpose2).get_node_shared_ptr(),
|
||||
pattern_to_output.at(matmul1).get_node_shared_ptr(),
|
||||
pattern_to_output.at(transpose3).get_node_shared_ptr(),
|
||||
};
|
||||
|
||||
if (fakeQuantize1Exists) {
|
||||
merged.push_back(pattern_to_output.at(fakeQuantize1).get_node_shared_ptr());
|
||||
}
|
||||
ngraph::copy_runtime_info(merged, mha);
|
||||
|
||||
if (transformation_callback(mha)) {
|
||||
return false;
|
||||
|
@ -372,8 +372,10 @@ static std::shared_ptr<ov::Model> initMHAQuantSubgraph0(std::vector<ov::PartialS
|
||||
return std::make_shared<ngraph::Function>(results, ngraphParam, "mha");
|
||||
}
|
||||
|
||||
static std::shared_ptr<ov::Model> initMHAQuantSubgraph1(std::vector<ov::PartialShape>& inputDynamicShapes, std::vector<ElementType>& inputPrecisions,
|
||||
std::vector<ElementType>& matMulIn0Precisions) {
|
||||
static std::shared_ptr<ov::Model> initMHAQuantSubgraph1(const std::vector<ov::PartialShape>& inputDynamicShapes,
|
||||
const std::vector<ElementType>& inputPrecisions,
|
||||
const std::vector<ElementType>& matMulIn0Precisions,
|
||||
const bool fakeQuantize3Exists) {
|
||||
ngraph::ParameterVector ngraphParam;
|
||||
|
||||
auto transpose0Param = std::make_shared<ngraph::opset1::Parameter>(inputPrecisions[0], inputDynamicShapes[0]);
|
||||
@ -428,8 +430,11 @@ static std::shared_ptr<ov::Model> initMHAQuantSubgraph1(std::vector<ov::PartialS
|
||||
const auto softMax = std::make_shared<ngraph::opset1::Softmax>(add, 3);
|
||||
const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(transpose2Param, transpose2Const);
|
||||
const auto matMul1 = std::make_shared<ngraph::opset3::MatMul>(softMax, transpose2, transA, transB);
|
||||
const auto fakeQuantize2 = ngraph::builder::makeFakeQuantize(matMul1, inputPrecisions[0], 256, {}, {0.0f}, {2.55f}, {0.0f}, {2.55f});
|
||||
const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(fakeQuantize2, transpose3Const);
|
||||
const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(
|
||||
fakeQuantize3Exists ?
|
||||
ngraph::builder::makeFakeQuantize(matMul1, inputPrecisions[0], 256, {}, { 0.0f }, { 2.55f }, { 0.0f }, { 2.55f }) :
|
||||
matMul1,
|
||||
transpose3Const);
|
||||
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(transpose3)};
|
||||
return std::make_shared<ngraph::Function>(results, ngraphParam, "mha");
|
||||
@ -503,7 +508,9 @@ protected:
|
||||
if (patternType == 0) {
|
||||
function = initMHAQuantSubgraph0(inputDynamicShapes, inputPrecisions, matMulIn0Precisions);
|
||||
} else if (patternType == 1) {
|
||||
function = initMHAQuantSubgraph1(inputDynamicShapes, inputPrecisions, matMulIn0Precisions);
|
||||
function = initMHAQuantSubgraph1(inputDynamicShapes, inputPrecisions, matMulIn0Precisions, true);
|
||||
} else if (patternType == 2) {
|
||||
function = initMHAQuantSubgraph1(inputDynamicShapes, inputPrecisions, matMulIn0Precisions, false);
|
||||
} else {
|
||||
FAIL() << "Unsupported MHA pattern type";
|
||||
}
|
||||
@ -559,7 +566,7 @@ std::vector<std::vector<ElementType>> matMulIn0PrecisionsQuant = {
|
||||
};
|
||||
|
||||
std::vector<size_t> patternTypesQuant = {
|
||||
0, 1
|
||||
0, 1, 2
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_MHAQuant, MHAQuantTest,
|
||||
|
Loading…
Reference in New Issue
Block a user