[CPU] f16 constant folding on cpu plug-in side for MatMul only (#18079)

This commit is contained in:
Anton Voronov 2023-07-20 08:58:46 +04:00 committed by GitHub
parent 0e76496acc
commit 60e40843c0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 356 additions and 12 deletions

View File

@ -14,6 +14,7 @@ namespace pass {
class TRANSFORMATIONS_API EnableDecompressionConvertConstantFolding;
class TRANSFORMATIONS_API DisableDecompressionConvertConstantFolding;
class TRANSFORMATIONS_API KeepConstAndDecompression;
class TRANSFORMATIONS_API KeepConstAndDecompressionForMatMul;
} // namespace pass
} // namespace ov
@ -47,3 +48,14 @@ public:
OPENVINO_RTTI("KeepConstAndDecompression", "0");
KeepConstAndDecompression();
};
/**
* @ingroup ie_transformation_common_api
* @brief Disables ConstantFolding for the decompression Convert that feeds the second input (port 1) of a
* MatMul operation, and only for such Converts. When that Convert is fed directly by a Constant, the Constant
* is additionally marked so that f16 Consts are not converted to f32.
*
* The pass only sets markers on the matched nodes; it does not replace or remove any node.
*/
class ov::pass::KeepConstAndDecompressionForMatMul : public MatcherPass {
public:
OPENVINO_RTTI("KeepConstAndDecompressionForMatMul", "0");
KeepConstAndDecompressionForMatMul();
};

View File

@ -7,6 +7,7 @@
#include "itt.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/convert.hpp"
#include "openvino/op/matmul.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/decompression.hpp"
#include "transformations/rt_info/disable_constant_folding.hpp"
@ -61,11 +62,36 @@ pass::KeepConstAndDecompression::KeepConstAndDecompression() {
disable_constant_folding(node);
if (!is_type<ov::op::v0::Constant>(node->input_value(0).get_node_shared_ptr()))
return true;
return false;
enable_keep_fp16_const(node->input_value(0).get_node_shared_ptr());
return true;
return false;
};
auto m = std::make_shared<pattern::Matcher>(node_pattern, matcher_name);
register_matcher(m, callback);
}
// Registers a matcher rooted at MatMul. For every matched MatMul whose second input (port 1) is a Convert
// marked as decompression, constant folding is disabled on that Convert; if the Convert reads directly
// from a Constant, the Constant is marked to be kept in fp16.
pass::KeepConstAndDecompressionForMatMul::KeepConstAndDecompressionForMatMul() {
    MATCHER_SCOPE(KeepConstAndDecompressionForMatMul);
    auto matmul = pass::pattern::wrap_type<ov::op::v0::MatMul>();

    matcher_pass_callback callback = [=](pass::pattern::Matcher& m) {
        auto matmul_node = m.get_match_root();

        // Only the producer of the second MatMul input is inspected.
        auto weights_producer = matmul_node->input_value(1).get_node_shared_ptr();
        if (is_type<ov::op::v0::Convert>(weights_producer) && is_decompression(weights_producer)) {
            disable_constant_folding(weights_producer);

            // The fp16 marker is only applied when the Convert is fed directly by a Constant.
            auto convert_source = weights_producer->input_value(0).get_node_shared_ptr();
            if (is_type<ov::op::v0::Constant>(convert_source)) {
                enable_keep_fp16_const(convert_source);
            }
        }

        // Every path reports "false": the callback only sets markers and replaces no node.
        return false;
    };

    auto m = std::make_shared<pass::pattern::Matcher>(matmul, matcher_name);
    this->register_matcher(m, callback);
}

View File

@ -9,6 +9,7 @@
#include "nodes/pooling.h"
#include "nodes/eltwise.h"
#include "nodes/concat.h"
#include "nodes/convert.h"
#include "nodes/reorder.h"
#include "nodes/conv.h"
#include "nodes/deconv.h"
@ -86,6 +87,10 @@ void GraphOptimizer::ApplyCommonGraphOptimizations(Graph &graph) {
MergeConvertAndScaleShift(graph);
graph.RemoveDroppedNodes();
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseFCAndConvertOnWeights");
FuseFCAndConvertOnWeights(graph);
graph.RemoveDroppedNodes();
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseDeconvolutionAndSimpleOperation");
FuseDeconvolutionAndSimpleOperation(graph);
graph.RemoveDroppedNodes();
@ -691,6 +696,20 @@ void GraphOptimizer::MergeConvertAndScaleShift(Graph& graph) {
}
}
// Drops constant Convert nodes (fp16 -> fp32/bf16) whose first child edge leads to a FullyConnected node,
// so that the precision conversion can be handled by the FC internal logic
// (e.g. fused with the weights reordering).
void GraphOptimizer::FuseFCAndConvertOnWeights(Graph& graph) {
    auto& graphNodes = graph.GetNodes();

    for (auto candidate : graphNodes) {
        // Must be a Convert that lies on a constant path.
        if (candidate->getType() != Type::Convert || !candidate->isConstant())
            continue;

        // Only the child reached through the first child edge is inspected.
        if (candidate->getChildEdgeAt(0)->getChild()->getType() != Type::FullyConnected)
            continue;

        // Source precision has to be fp16 ...
        if (!(candidate->getOriginalInputPrecisionAtPort(0) == Precision::FP16))
            continue;

        // ... and the target precision fp32 or bf16.
        if (!one_of(candidate->getOriginalOutputPrecisionAtPort(0), Precision::FP32, Precision::BF16))
            continue;

        graph.DropNode(candidate);
    }
}
void GraphOptimizer::FuseConvolutionAndZeroPoints(Graph &graph) {
auto& graphNodes = graph.GetNodes();

View File

@ -25,6 +25,7 @@ private:
void FuseDeconvolutionAndSimpleOperation(Graph &graph);
void FuseMultiplyAndAdd(Graph &graph);
void MergeConvertAndScaleShift(Graph& graph);
void FuseFCAndConvertOnWeights(Graph& graph);
void FuseFullyConnectedAndSimpleOperation(Graph &graph);
void FuseMatMulAndSimpleOperation(Graph &graph);
void FuseConvolutionAndSimpleOperationThroughMaxPool(Graph &graph);

View File

@ -14,7 +14,7 @@
ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
MATCHER_SCOPE(ConvertMatMulToFC);
auto activations_m = ngraph::pattern::any_input(ngraph::pattern::has_static_rank());
auto weights_m = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
auto weights_m = ngraph::pattern::wrap_type<ngraph::opset1::Constant, ngraph::opset1::Convert>(ngraph::pattern::has_static_rank());
auto matmul_m = ngraph::pattern::wrap_type<ngraph::opset1::MatMul>({ activations_m, weights_m }, ngraph::pattern::has_static_rank());
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
@ -29,6 +29,15 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
// So in case of adding new operations that takes matmul inputs we need keep update fc_input_a and fc_input_b.
auto fc_input_a = pattern_map.at(activations_m);
auto fc_input_b = pattern_map.at(weights_m);
bool is_convert = false;
if (auto convert_node = std::dynamic_pointer_cast<ngraph::opset1::Convert>(fc_input_b.get_node_shared_ptr())) {
if (is_decompression(convert_node)) {
is_convert = true;
fc_input_b = convert_node->get_input_node_shared_ptr(0);
} else {
return false;
}
}
auto shape_a = fc_input_a.get_partial_shape();
auto shape_b = fc_input_b.get_partial_shape();
@ -45,8 +54,7 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
// Check that if second inputs is Constant path and it's shape without ones dimensions has length <= 2
// we replace MatMul with FullyConnected operation.
if (!std::dynamic_pointer_cast<ngraph::opset1::Constant>(fc_input_b.get_node_shared_ptr()) ||
std::count_if(shape_b.begin(), shape_b.end(), [](ngraph::Dimension x) { return x != 1; }) > 2) {
if (std::count_if(shape_b.begin(), shape_b.end(), [](ngraph::Dimension x) { return x != 1; }) > 2) {
return false;
}
/*
@ -147,9 +155,18 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
fc_input_a = create_transpose(fc_input_a, matmul->get_friendly_name() + "/transpose_a");
}
auto output_rank = matmul->get_output_partial_shape(0).rank();
// Connect Convert to new input if needed
if (is_convert) {
auto convert = pattern_map.at(weights_m).get_node_shared_ptr();
convert->input(0).replace_source_output(fc_input_b);
convert->validate_and_infer_types();
fc_input_b = convert;
}
// Create FullyConnected
auto fc = std::make_shared<ov::intel_cpu::FullyConnectedNode>(fc_input_a, fc_input_b, output_rank, matmul->get_output_element_type(0));
auto output_rank = matmul->get_output_partial_shape(0).rank();
auto fc = std::make_shared<ov::intel_cpu::FullyConnectedNode>(fc_input_a, fc_input_b, output_rank,
matmul->get_output_element_type(0));
fc->set_friendly_name(matmul->get_friendly_name());
new_ops.push_back(fc);
ngraph::copy_runtime_info(matmul, new_ops);

View File

@ -253,7 +253,7 @@ bool isSuitableChildForFusingMatMul(const std::shared_ptr<const Node> &node, con
ov::PartialShape matmul_shape;
for (const auto &parent_out : node->input_values()) {
const auto parent = parent_out.get_node_shared_ptr();
if (ov::is_type<ov::op::v0::Constant>(parent)) {
if (ov::is_type<ov::op::v0::Constant>(parent) || ov::is_type<ov::op::v0::Convert>(parent)) {
bias_shape = parent_out.get_shape();
num_non_const_inputs++;
} else {
@ -264,7 +264,8 @@ bool isSuitableChildForFusingMatMul(const std::shared_ptr<const Node> &node, con
// first check that weights are constant and both activations and weights have static shape
if (grandparents.size() == 2 &&
grandparents[1].get_partial_shape().is_static() &&
ov::is_type<ov::op::v0::Constant>(grandparents[1].get_node_shared_ptr())) {
(ov::is_type<ov::op::v0::Constant>(grandparents[1].get_node_shared_ptr())
|| ov::is_type<ov::op::v0::Convert>(grandparents[1].get_node_shared_ptr()))) {
auto rank_a = grandparents[0].get_partial_shape().rank().get_length();
auto rank_w = grandparents[1].get_partial_shape().rank().get_length();
if (rank_a != 1 && rank_w != 1 && rank_a <= 3 && rank_w <= 3)

View File

@ -202,8 +202,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
manager.set_per_pass_validation(false);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::InitNodeInfo);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::MarkShapeOfSubgraphs);
// todo: uncomment KeepConstAndDecompression when xxx-105060 is ready
// CPU_REGISTER_PASS_COMMON(manager, ov::pass::KeepConstAndDecompression);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::KeepConstAndDecompressionForMatMul);
const bool useLpt = !defaultPrecisions.empty();
if (useLpt) {
@ -464,6 +463,13 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
ov::pass::ConvertQuantizeDequantize);
}
/* In some cases, during the transformation pipeline, some MatMul nodes can be transformed into other nodes. For example, they can become part of
AUGRUCell node (see AUGRUCellFusion pass). In such cases, some constant paths will remain unfolded, which can lead to crashes in the plugin. To avoid this,
we re-mark the decompression converts and finally do CF for those constant paths that are not inputs to a MatMul node */
CPU_REGISTER_PASS_COMMON(manager, ov::pass::EnableDecompressionConvertConstantFolding);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::KeepConstAndDecompressionForMatMul);
CPU_REGISTER_PASS_COMMON(manager, ov::pass::ConstantFolding);
manager.run_passes(model);
}

View File

@ -0,0 +1,217 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "test_utils/fusing_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "transformations/rt_info/decompression.hpp"
using namespace ngraph;
using namespace InferenceEngine;
using namespace CPUTestUtils;
using namespace ov::test;
namespace SubgraphTestsDefinitions {
/* This test checks that the ConvertMatMulToFC transformation should work and the MatMul node is converted to the FC node.
* The Convert node should be removed on the CPU plugin side.
* Graph before:
------------ ------------
|Input(f32)| |Input(f16)|
------------ ------------
| |
| ---------------------------------
| |Convert(decompression f16->f32)|
| ---------------------------------
| |
-----------------------------------------------
| MatMul |
-----------------------------------------------
|
--------
|Output|
--------
* Exec graph:
------------ ------------
|Input(f32)| |Input(f16)|
------------ ------------
| |
----------------------------
| FullyConnected |
----------------------------
|
--------
|Output|
--------
*/
// Parameter tuple consumed by MatMulDecompressConvertTest (unpacked with std::tie in
// getTestCaseName() and SetUp()).
using MatMulDecompressConvertParams = std::tuple<
std::vector<InputShape>, // input shapes: [0] is the activations input, [1] is the shape of the f16 weights constant
std::pair<bool, bool>, // transposeA, transposeB flags passed to the MatMul builder
std::map<std::string, std::string> // additional plugin config (inserted into `configuration`)
>;
// Subgraph test fixture: Parameter -> MatMul <- Convert(decompression) <- Constant(f16).
// The test body expects the MatMul to be executed as a FullyConnected node that reads an FP16 Const.
class MatMulDecompressConvertTest : public testing::WithParamInterface<MatMulDecompressConvertParams>,
                                    virtual public SubgraphBaseTest, public CPUTestsBase {
public:
    // Builds a readable test-case name from input shapes, transpose flags and the additional config.
    static std::string getTestCaseName(testing::TestParamInfo<MatMulDecompressConvertParams> obj) {
        std::vector<InputShape> inputShapes;
        std::pair<bool, bool> transpose;
        std::map<std::string, std::string> additionalConfig;
        std::tie(inputShapes, transpose, additionalConfig) = obj.param;

        std::ostringstream name;

        // Dynamic (partial) shapes first ...
        for (const auto& inputShape : inputShapes) {
            name << CommonTestUtils::partialShape2str({inputShape.first}) << "_";
        }

        // ... then the list of target static shapes for every input, joined with "_".
        name << "TS=";
        for (const auto& inputShape : inputShapes) {
            const auto& staticShapes = inputShape.second;
            name << "(";
            for (auto it = staticShapes.begin(); it != staticShapes.end(); ++it) {
                if (it != staticShapes.begin()) {
                    name << "_";
                }
                name << CommonTestUtils::vec2str(*it);
            }
            name << ")_";
        }

        name << "transpose_a=" << transpose.first << "_";
        name << "transpose_b=" << transpose.second << "_";

        name << "config=(";
        for (const auto& entry : additionalConfig) {
            name << entry.first << ", " << entry.second << ":";
        }
        name << ")";

        return name.str();
    }

protected:
    // Swaps the two innermost dimensions of `shape` (requires at least two dimensions).
    template<typename T>
    void transposeShape(T& shape) {
        IE_ASSERT(shape.size() > 1);
        auto lastDim = shape.end() - 1;
        auto prevDim = shape.end() - 2;
        std::swap(*lastDim, *prevDim);
    }

    // Verifies on the runtime model that input 1 of every FullyConnected node is a "Const" layer
    // whose output precision is reported as "FP16".
    void CheckConstFP16() const {
        auto getExecValue = [](const ov::Node::RTMap& rtInfo, const std::string &paramName) -> std::string {
            const auto found = rtInfo.find(paramName);
            IE_ASSERT(found != rtInfo.end());
            return found->second.as<std::string>();
        };

        const auto runtimeModel = compiledModel.get_runtime_model();
        ASSERT_NE(nullptr, runtimeModel);

        for (const auto &op : runtimeModel->get_ops()) {
            if (getExecValue(op->get_rt_info(), ExecGraphInfoSerialization::LAYER_TYPE) != "FullyConnected") {
                continue;
            }

            const auto &weights = op->get_input_node_shared_ptr(1);
            ASSERT_EQ(getExecValue(weights->get_rt_info(), ExecGraphInfoSerialization::LAYER_TYPE), "Const");
            ASSERT_EQ(getExecValue(weights->get_rt_info(), ExecGraphInfoSerialization::OUTPUT_PRECISIONS), "FP16");
        }
    }

    // Unpacks the test parameters and builds the tested function.
    void SetUp() override {
        targetDevice = CommonTestUtils::DEVICE_CPU;

        std::vector<InputShape> inputShapes;
        std::pair<bool, bool> transpose;
        std::map<std::string, std::string> additionalConfig;
        std::tie(inputShapes, transpose, additionalConfig) = this->GetParam();

        init_input_shapes(inputShapes);

        const bool transpA = transpose.first;
        const bool transpB = transpose.second;

        // The shapes in the parameters describe the logical matrices; when an input is consumed
        // transposed, the actual tensor shapes (dynamic and static) get their last two dims swapped.
        auto transposeInput = [this](size_t inputIdx) {
            this->transposeShape(inputDynamicShapes[inputIdx]);
            for (auto& staticShapes : targetStaticShapes) {
                this->transposeShape(staticShapes[inputIdx]);
            }
        };
        if (transpA) {
            transposeInput(0);
        }
        if (transpB) {
            transposeInput(1);
        }

        const auto& inShapeA = inputDynamicShapes[0];
        const auto& inShapeB = inputDynamicShapes[1];

        configuration.insert(additionalConfig.begin(), additionalConfig.end());

        // Network precision is f32 unless bf16 is enforced through the config.
        ElementType netType = element::f32;
        if (additionalConfig[PluginConfigParams::KEY_ENFORCE_BF16] == PluginConfigParams::YES) {
            netType = ElementType::bf16;
        }
        inType = outType = netType;

        std::string cpuNodeType = "FullyConnected";

        auto params = builder::makeDynamicParams(inType, {inShapeA});
        auto paramOuts = helpers::convert2OutputVector(helpers::castOps2Nodes<opset1::Parameter>(params));

        // f16 weights followed by a Convert explicitly marked as decompression.
        auto weights = ngraph::builder::makeConstant<float16>(element::f16, inShapeB.get_shape(), {}, true);
        auto decompression = std::make_shared<ngraph::opset1::Convert>(weights, inType);
        mark_as_decompression(decompression);

        auto matMul = builder::makeMatMul(paramOuts[0], decompression, transpA, transpB);

        function = CPUTestsBase::makeNgraphFunction(netType, params, matMul, cpuNodeType);
    }
};
// Runs the subgraph test and then inspects the compiled model.
TEST_P(MatMulDecompressConvertTest, CompareWithRefs) {
run();
// Exactly one FullyConnected node is expected in the compiled model.
CheckNumberOfNodesWithType(compiledModel, "FullyConnected", 1);
// No Convert and no Reorder nodes are expected in the compiled model.
CheckNumberOfNodesWithType(compiledModel, "Convert", 0);
CheckNumberOfNodesWithType(compiledModel, "Reorder", 0);
// Input 1 of the FullyConnected node must be a Const layer with FP16 output precision.
CheckConstFP16();
}
namespace {
// All four {transposeA, transposeB} combinations.
const std::vector<std::pair<bool, bool>> transposeParams = {
{false, false},
{false, true},
{true, false},
{true, true},
};
// Returns the plugin configs to test with: the empty config is always present, the
// enforce-bf16 config is appended only when with_cpu_x86_avx512_core() reports true.
std::vector<std::map<std::string, std::string>> filterAdditionalConfig() {
    std::vector<std::map<std::string, std::string>> configs;

    // Empty (default) config.
    configs.emplace_back();

    if (with_cpu_x86_avx512_core()) {
        std::map<std::string, std::string> enforceBf16;
        enforceBf16.emplace(PluginConfigParams::KEY_ENFORCE_BF16, PluginConfigParams::YES);
        configs.push_back(enforceBf16);
    }

    return configs;
}
// 2D activations x 2D weights, for every transpose combination and every config.
const auto testParams2D_smoke = ::testing::Combine(
::testing::Values(static_shapes_to_test_representation({{2, 3}, {3, 4}})),
::testing::ValuesIn(transposeParams),
::testing::ValuesIn(filterAdditionalConfig()));
INSTANTIATE_TEST_SUITE_P(smoke_FC_2D, MatMulDecompressConvertTest, testParams2D_smoke,
MatMulDecompressConvertTest::getTestCaseName);
// Mixed ranks: 3D activations with 2D weights, and 2D activations with 3D weights (leading dim 1).
const auto testParams3D_smoke = ::testing::Combine(
::testing::Values(static_shapes_to_test_representation({{1, 2, 3}, {3, 4}}),
static_shapes_to_test_representation({{2, 3}, {1, 3, 4}})),
::testing::ValuesIn(transposeParams),
::testing::ValuesIn(filterAdditionalConfig()));
INSTANTIATE_TEST_SUITE_P(smoke_FC_3D, MatMulDecompressConvertTest, testParams3D_smoke,
MatMulDecompressConvertTest::getTestCaseName);
} // namespace
} // namespace SubgraphTestsDefinitions

View File

@ -19,6 +19,7 @@
#include <ov_ops/type_relaxed.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "transformations/rt_info/decompression.hpp"
using namespace testing;
using namespace ov::intel_cpu;
@ -318,3 +319,47 @@ TEST_F(TransformationTestsF, ConvertMatMulToFCTest_second_input_rank_adj_3_witho
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
}
}
// MatMul whose second input is Constant(f16) -> Convert(f32) marked as decompression:
// ConvertMatMulToFC is expected to produce a FullyConnectedNode that still reads the weights through a Convert.
TEST_F(TransformationTestsF, ConvertMatMulToFCTest_decompress_convert_0) {
{
// Tested function: weights constant has shape {1, 2, 2}, no transposes on the MatMul.
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 });
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f16, ngraph::Shape{ 1, 2, 2 }, { 1 });
auto convert = std::make_shared<ngraph::opset1::Convert>(input2, ngraph::element::f32);
ov::mark_as_decompression(convert);
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, convert, false, false);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
manager.register_pass<ConvertMatMulToFC>();
}
{
// Reference function: weights constant has shape {2, 2} (leading 1 dimension gone); the Convert here
// is not marked as decompression.
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 });
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f16, ngraph::Shape{2, 2 }, { 1 });
auto convert = std::make_shared<ngraph::opset1::Convert>(input2, ngraph::element::f32);
auto matmul = std::make_shared<FullyConnectedNode>(input1, convert, ngraph::Rank(3));
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
}
}
// Same as the decompress_convert_0 case but with transpose_a = true on the MatMul:
// the reference expects an explicit Transpose on the activations in front of the FullyConnectedNode.
TEST_F(TransformationTestsF, ConvertMatMulToFCTest_decompress_convert_1) {
{
// Tested function: weights constant has shape {1, 2, 2}; MatMul is created with transpose_a = true.
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 });
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f16, ngraph::Shape{ 1, 2, 2 }, { 1 });
auto convert = std::make_shared<ngraph::opset1::Convert>(input2, ngraph::element::f32);
ov::mark_as_decompression(convert);
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, convert, true, false);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
manager.register_pass<ConvertMatMulToFC>();
}
{
// Reference function: activations go through Transpose with order {0, 2, 1}; weights constant has
// shape {2, 2}; the Convert here is not marked as decompression.
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 });
auto transpose_constant = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 3 }, { 0, 2, 1 });
auto transpose = std::make_shared<ngraph::opset1::Transpose>(input1, transpose_constant);
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f16, ngraph::Shape{2, 2 }, { 1 });
auto convert = std::make_shared<ngraph::opset1::Convert>(input2, ngraph::element::f32);
auto matmul = std::make_shared<FullyConnectedNode>(transpose, convert, ngraph::Rank(3));
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
}
}

@ -1 +1 @@
Subproject commit 33bb2b261d3829162395aaa9bbe8c1c5b139e855
Subproject commit 3efb012c7f96b2d5e47270632cdf4b9e4b79b1b8