[CPU] MatMul: move i8 compressed weights constant folding to the plugin (#18718)

- Reused the LPT pass to disable ConstantFolding for the decompression subgraph
- GraphOptimizer: added the FuseFCAndWeightsDecompression transformation
- Adapted the transformation pipeline to MatMuls with compressed weights (the targeted pattern is sketched below)
- Added the MoveFCReshapeToWeights CPU transformation
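
For reference, a minimal sketch of the compressed-weights pattern these changes target, built with the OpenVINO C++ API (shapes and values here are illustrative, not taken from this commit):

#include <openvino/openvino.hpp>
#include <openvino/opsets/opset10.hpp>

// u8 weights -> Convert -> Subtract(zero point) -> Multiply(scale) -> MatMul
std::shared_ptr<ov::Model> make_compressed_matmul() {
    auto data = std::make_shared<ov::opset10::Parameter>(ov::element::f32, ov::PartialShape{-1, 16});
    auto weights = ov::opset10::Constant::create(ov::element::u8, ov::Shape{16, 32}, {3});
    auto convert = std::make_shared<ov::opset10::Convert>(weights, ov::element::f32);
    auto zero_point = ov::opset10::Constant::create(ov::element::f32, ov::Shape{}, {127});
    auto subtract = std::make_shared<ov::opset10::Subtract>(convert, zero_point);
    auto scale = ov::opset10::Constant::create(ov::element::f32, ov::Shape{}, {0.2});
    auto multiply = std::make_shared<ov::opset10::Multiply>(subtract, scale);
    auto matmul = std::make_shared<ov::opset10::MatMul>(data, multiply);
    return std::make_shared<ov::Model>(ov::OutputVector{matmul}, ov::ParameterVector{data});
}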
Vladislav Golubev 2023-07-31 10:44:09 +02:00 committed by GitHub
parent 349d159327
commit 2b5ca40eb6
19 changed files with 814 additions and 54 deletions


@@ -50,6 +50,10 @@ ngraph::pass::low_precision::ConvertSubtractConstant::ConvertSubtractConstant(co
const auto quantizePrecision = weightsConvert->get_input_element_type(0);
const auto dequantizationPrecision = weightsConvert->get_output_element_type(0);
if (transformation_callback(m.get_match_root())) {
return false;
}
// validation by Convert operation input precisions
if (!constantPrecisions.empty()) {
const ngraph::element::Type inputPrecision = quantizePrecision;


@@ -469,3 +469,53 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNotConstant
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}
TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationFoldSubConst) {
// Input graph:                      After transformation:
//
//  Constant   Constant                Constant
//   |U8        |U8                     |U8
//   |          |                       |
//  Convert   Convert                 Convert(DCF)  Constant
//   |FP32    /FP32                     |FP32      /FP32
//   |       /                           \        /
//  Subtract   Constant                  Subtract   Constant
//   |FP32    /FP32                       |FP32    /FP32
//   |       /                             \      /
//  Multiply                               Multiply
//
// After MarkDequantizationSubgraph, all Subtract and Multiply nodes from the graph above
// are marked with the 'DequantizationNode' attribute.
// The 'Convert(DCF)' node on the weights path is also marked with the 'DisableConstantFolding'
// attribute, but the Convert before the dequantization Subtract constant isn't,
// because fold_subtract_const is set to true
{
auto weights = opset10::Constant::create(element::u8, Shape{4, 16, 1, 1}, {3});
auto convert = std::make_shared<opset10::Convert>(weights, element::f32);
auto zero_point = opset10::Constant::create(element::u8, Shape{}, {127});
auto convert_on_zero_point = std::make_shared<opset10::Convert>(zero_point, element::f32);
auto subtract = std::make_shared<opset10::Subtract>(convert, convert_on_zero_point);
auto scale = opset10::Constant::create(element::f32, Shape{}, {0.2});
auto multiply = std::make_shared<opset10::Multiply>(subtract, scale);
function = std::make_shared<ov::Model>(ov::OutputVector{multiply});
}
manager.register_pass<pass::MarkDequantizationSubgraph>(element::TypeVector{element::u8}, true);
manager.register_pass<pass::ConstantFolding>();
{
auto weights = opset10::Constant::create(element::u8, Shape{4, 16, 1, 1}, {3});
auto convert = std::make_shared<opset10::Convert>(weights, element::f32);
pass::disable_constant_folding(convert);
auto zero_point = opset10::Constant::create(element::f32, Shape{}, {127});
auto subtract = std::make_shared<opset10::Subtract>(convert, zero_point);
mark_as_dequantization_node(subtract);
auto scale = opset10::Constant::create(element::f32, Shape{}, {0.2});
auto multiply = std::make_shared<opset10::Multiply>(subtract, scale);
mark_as_dequantization_node(multiply);
function_ref = std::make_shared<ov::Model>(ov::OutputVector{multiply});
}
comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}


@@ -22,7 +22,7 @@ namespace pass {
class TRANSFORMATIONS_API MarkDequantizationSubgraph : public MatcherPass {
public:
OPENVINO_RTTI("MarkDequantizationSubgraph", "0");
MarkDequantizationSubgraph(const element::TypeVector& precisions);
MarkDequantizationSubgraph(const element::TypeVector& precisions, const bool fold_subtract_const = false);
};
} // namespace pass
} // namespace ov
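
As exercised by the test above, the new flag is passed at registration time; a minimal usage sketch (model stands for any std::shared_ptr<ov::Model>):

ov::pass::Manager manager;
manager.register_pass<ov::pass::MarkDequantizationSubgraph>(
    ov::element::TypeVector{ov::element::u8}, /*fold_subtract_const=*/true);
manager.register_pass<ov::pass::ConstantFolding>();
manager.run_passes(model);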


@@ -11,7 +11,8 @@
#include <transformations/rt_info/disable_constant_folding.hpp>
#include <transformations/utils/utils.hpp>
ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::TypeVector& precisions) {
ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::TypeVector& precisions,
const bool fold_subtract_const) {
// Dequantization subgraph may have two forms: with and without Subtract
//
// Input Input
@@ -36,6 +37,10 @@ ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::
auto input = pattern_map.at(input_pattern).get_node_shared_ptr();
const auto multiply = m.get_match_root();
if (transformation_callback(multiply)) {
return false;
}
auto subtract_it = pattern_map.find(subtract_pattern);
if (subtract_it == pattern_map.end()) {
for (size_t i = 0; i < multiply->get_input_size(); i++) {
@@ -63,7 +68,8 @@ ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::
// mark Subtract as dequantization node
ov::mark_as_dequantization_node(subtract_it->second.get_node_shared_ptr());
auto zero_point = pattern_map.at(zero_point_pattern).get_node_shared_ptr();
if (ov::is_type<opset10::Convert>(zero_point) && input_precision == zero_point->get_input_element_type(0) &&
if (!fold_subtract_const && ov::is_type<opset10::Convert>(zero_point) &&
input_precision == zero_point->get_input_element_type(0) &&
ov::is_type<opset10::Constant>(zero_point->get_input_node_ptr(0))) {
// disable ConstantFolding also for Convert on zero_point
// so we don't have to constantfold it and then convert it back to


@@ -277,6 +277,101 @@ void GraphOptimizer::FuseConvMatmulFCDeconvAndDQScales(Graph &graph) {
}
}
void GraphOptimizer::FuseFCAndWeightsDecompression(Graph &graph) {
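// Matches the weights path of a FullyConnected node (walking up from its second input):
//   FC <- [Transpose] <- Multiply (x scale const) <- [Subtract (- shift const)] <- Convert <- Input (u8 weights)
// and folds the decompression constants into the FC node itself (optional nodes in brackets).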
const std::set<InferenceEngine::Precision> supportedWeightsPrecisions{InferenceEngine::Precision::U8};
auto expectedNode = [](NodePtr node, Type expectedType) {
return node->getType() == expectedType && node->getChildEdges().size() == 1;
};
auto& graphNodes = graph.GetNodes();
for (size_t i = 0; i < graphNodes.size(); i++) {
const auto fcNode = dynamic_cast<node::FullyConnected*>(graphNodes[i].get());
if (fcNode == nullptr)
continue;
const auto parent = fcNode->getParentEdgesAtPort(1)[0]->getParent();
const bool withTranspose = parent->getType() == Type::Transpose;
const NodePtr transposeNode = withTranspose ? parent : nullptr;
const auto multiplyNode = withTranspose ? parent->getParentEdgesAtPort(0)[0]->getParent() : parent;
if (!expectedNode(multiplyNode, Type::Eltwise) || multiplyNode->getAlgorithm() != Algorithm::EltwiseMultiply ||
!multiplyNode->isConstant())
continue;
CPU_GRAPH_OPTIMIZER_SCOPE(FuseFCAndWeightsDecompression);
const auto multiplyConstNode = multiplyNode->getParentEdgesAtPort(1)[0]->getParent();
if (!expectedNode(multiplyConstNode, Type::Input))
continue;
const auto mulParent = multiplyNode->getParentEdgesAtPort(0)[0]->getParent();
const bool withSubtract = mulParent->getAlgorithm() == Algorithm::EltwiseSubtract;
NodePtr subtractNode, subtractConstNode;
if (withSubtract) {
subtractNode = mulParent;
if (!expectedNode(subtractNode, Type::Eltwise))
continue;
subtractConstNode = subtractNode->getParentEdgesAtPort(1)[0]->getParent();
if (!expectedNode(subtractConstNode, Type::Input))
continue;
}
const auto convertNode = withSubtract ? subtractNode->getParentEdgesAtPort(0)[0]->getParent() : mulParent;
if (!expectedNode(convertNode, Type::Convert))
continue;
const auto weightsNode = convertNode->getParentEdgesAtPort(0)[0]->getParent();
if (!expectedNode(weightsNode, Type::Input))
continue;
// Precision limitations
if (multiplyConstNode->getOriginalOutputPrecisionAtPort(0) != Precision::FP32)
continue;
if (supportedWeightsPrecisions.find(weightsNode->getOriginalOutputPrecisionAtPort(0)) == supportedWeightsPrecisions.end())
continue;
if (withSubtract && subtractConstNode->getOriginalOutputPrecisionAtPort(0) != Precision::FP32)
continue;
// Shape limitations
const auto weightsShape = weightsNode->getOutputShapeAtPort(0);
const auto fcInputWeightsShape = multiplyNode->getOutputShapeAtPort(0);
if (weightsShape != fcInputWeightsShape)
continue;
const auto expectedDims = withTranspose ? VectorDims{1, weightsShape.getDims()[1]}
: VectorDims{weightsShape.getDims()[0], 1};
if (multiplyConstNode->getOutputShapeAtPort(0).getDims() != expectedDims)
continue;
if (withSubtract && subtractConstNode->getOutputShapeAtPort(0).getDims() != expectedDims)
continue;
fcNode->fuseDecompressionMultiply(multiplyConstNode);
if (withSubtract)
fcNode->fuseDecompressionSubtract(subtractConstNode);
fcNode->addOriginalLayer(multiplyNode->getOriginalLayers());
fcNode->addOriginalLayer(convertNode->getOriginalLayers());
if (withSubtract) {
fcNode->addOriginalLayer(subtractNode->getOriginalLayers());
auto subtractConstEdge = subtractConstNode->getChildEdges()[0].lock();
graph.RemoveEdge(subtractConstEdge);
}
auto multiplyConstEdge = multiplyConstNode->getChildEdges()[0].lock();
graph.RemoveEdge(multiplyConstEdge);
graph.DropNode(convertNode);
if (withSubtract)
graph.DropNode(subtractNode);
graph.DropNode(multiplyNode);
const auto& weightsPrecision = weightsNode->getOriginalOutputPrecisionAtPort(0);
if (withTranspose) {
transposeNode->setOriginalInputPrecisionAtPort(0, weightsPrecision);
transposeNode->setOriginalOutputPrecisionAtPort(0, weightsPrecision);
}
fcNode->setOriginalInputPrecisionAtPort(1, weightsPrecision);
}
}
void GraphOptimizer::FuseConvolutionMatMulDeconvAndBias(Graph &graph) {
auto& graphNodes = graph.GetNodes();


@@ -21,6 +21,7 @@ public:
private:
void FuseConvMatmulFCDeconvAndDQScales(Graph &graph);
void FuseFCAndWeightsDecompression(Graph &graph);
void FuseConvolutionMatMulDeconvAndBias(Graph &graph);
void FuseDeconvolutionAndSimpleOperation(Graph &graph);
void FuseMultiplyAndAdd(Graph &graph);


@@ -25,6 +25,7 @@
#include "common/primitive_hashing_utils.hpp"
#include "common/primitive_desc.hpp"
#include "common/primitive_desc_iface.hpp"
#include "common/cpu_convert.h"
#include <string>
#include <vector>
@@ -1083,6 +1084,28 @@ bool FullyConnected::useSparseWeightsDecompression() {
return true;
}
void FullyConnected::fuseDecompressionMultiply(const NodePtr& constData) {
fuseDecompressionConstant(constData, decompressionMultiply);
}
void FullyConnected::fuseDecompressionSubtract(const NodePtr& constData) {
fuseDecompressionConstant(constData, decompressionSubtract);
}
void FullyConnected::fuseDecompressionConstant(const NodePtr& constData, std::vector<float>& decompressionValues) {
auto *constInputNode = dynamic_cast<node::Input *>(constData.get());
if (!constInputNode) {
IE_THROW() << "Cannot cast " << constData->getName() << " to Input";
}
auto constBlob = constInputNode->getMemoryPtr();
const auto elementsCount = constBlob->getDescWithType<BlockedMemoryDesc>()->getPaddedElementsCount();
decompressionValues.resize(elementsCount);
cpu_convert(constBlob->getData(),
&decompressionValues[0],
DnnlExtensionUtils::DataTypeToIEPrecision(constBlob->getDataType()),
Precision::FP32,
elementsCount);
}
} // namespace node
} // namespace intel_cpu
} // namespace ov


@@ -57,6 +57,12 @@ public:
void executeDynamicImpl(dnnl::stream strm) override;
bool canBeExecutedInInt8() const override;
void fuseDecompressionMultiply(const NodePtr& constData);
const std::vector<float>& getDecompressionMultiply() const { return decompressionMultiply; }
void fuseDecompressionSubtract(const NodePtr& constData);
const std::vector<float>& getDecompressionSubtract() const { return decompressionSubtract; }
private:
void createDescriptorInternal(const dnnl::memory::desc &inputDesc,
const dnnl::memory::desc &outputDesc);
@@ -93,6 +99,7 @@ private:
const dnnl::engine& engine);
bool canBeExecutedInConv1x1() const;
void fuseDecompressionConstant(const NodePtr& constData, std::vector<float>& decompressionValues);
// sparse weights
bool useSparseWeights = false;
@@ -107,6 +114,9 @@ private:
void executeMLAS();
void prepackMLASWeight();
#endif
std::vector<float> decompressionSubtract;
std::vector<float> decompressionMultiply;
};
} // namespace node


@@ -4,7 +4,10 @@
#include "transformations/cpu_opset/common/op/fully_connected.hpp"
#include "convert_matmul_to_fc.hpp"
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/op/matmul.hpp>
#include <ngraph/op/convert.hpp>
#include <ngraph/op/transpose.hpp>
#include <ngraph/op/reshape.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <transformations/utils/utils.hpp>
@@ -14,13 +17,16 @@
ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
MATCHER_SCOPE(ConvertMatMulToFC);
auto activations_m = ngraph::pattern::any_input(ngraph::pattern::has_static_rank());
auto weights_m = ngraph::pattern::wrap_type<ngraph::opset1::Constant, ngraph::opset1::Convert>(ngraph::pattern::has_static_rank());
auto matmul_m = ngraph::pattern::wrap_type<ngraph::opset1::MatMul>({ activations_m, weights_m }, ngraph::pattern::has_static_rank());
auto weights_path = [](const ov::Output<ov::Node>& output) {
return ov::op::util::is_on_constant_path(output.get_node_shared_ptr());
};
auto weights_m = ngraph::pattern::any_input(weights_path);
auto matmul_m = ngraph::pattern::wrap_type<ngraph::op::v0::MatMul>({ activations_m, weights_m }, ngraph::pattern::has_static_rank());
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
const auto& pattern_map = m.get_pattern_value_map();
auto matmul = std::dynamic_pointer_cast<ngraph::opset1::MatMul>(pattern_map.at(matmul_m).get_node_shared_ptr());
auto matmul = std::dynamic_pointer_cast<ngraph::op::v0::MatMul>(pattern_map.at(matmul_m).get_node_shared_ptr());
if (!matmul || transformation_callback(matmul)) {
return false;
}
@@ -30,7 +36,7 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
auto fc_input_a = pattern_map.at(activations_m);
auto fc_input_b = pattern_map.at(weights_m);
bool is_convert = false;
if (auto convert_node = std::dynamic_pointer_cast<ngraph::opset1::Convert>(fc_input_b.get_node_shared_ptr())) {
if (auto convert_node = std::dynamic_pointer_cast<ngraph::op::v0::Convert>(fc_input_b.get_node_shared_ptr())) {
if (is_decompression(convert_node)) {
is_convert = true;
fc_input_b = convert_node->get_input_node_shared_ptr(0);
@@ -74,7 +80,7 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
shape_b_aligned.insert(shape_b_aligned.begin(), 1);
}
if (matmul->get_transpose_a() && rank_a != 1) {
if (matmul->get_transpose_a()) {
std::swap(*(shape_a_aligned.end() - 1), *(shape_a_aligned.end() - 2));
}
if (matmul->get_transpose_b()) {
@@ -102,14 +108,13 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
ngraph::NodeVector new_ops;
auto create_transpose = [this, &new_ops ](const ngraph::Output<ngraph::Node>& node, const std::string& transpose_name) {
auto rank = node.get_partial_shape().rank();
std::vector<size_t> transpose_order(rank.get_length());
std::vector<size_t> transpose_order(node.get_partial_shape().size());
std::iota(transpose_order.begin(), transpose_order.end(), 0);
std::swap(*(transpose_order.end() - 1), *(transpose_order.end() - 2));
auto transpose_const = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ transpose_order.size() }, transpose_order);
auto transpose = ov::op::util::make_try_fold<ngraph::opset1::Transpose>(node, transpose_const);
if (!ngraph::is_type<ngraph::opset1::Constant>(transpose)) {
auto transpose_const = ngraph::op::v0::Constant::create(ngraph::element::i32, ngraph::Shape{ transpose_order.size() }, transpose_order);
auto transpose = ov::op::util::make_try_fold<ngraph::op::v1::Transpose>(node, transpose_const);
if (!ngraph::is_type<ngraph::op::v0::Constant>(transpose)) {
new_ops.push_back(transpose_const);
MatcherPass::register_new_node(transpose);
}
@@ -133,25 +138,26 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
// Transferring from MatMul representation: [B, I, K] * [B, K, O] = [B, I, O]
// to FullyConnected representation: [I, K] * [K, O] = [I, O]
// Weights normalization
if (!matmul->get_transpose_b()) {
fc_input_b = create_transpose(fc_input_b, matmul->get_friendly_name() + "/transpose_b");
}
if (rank_b != 2) {
ngraph::Dimension K = *(shape_b_aligned.rbegin() + 1);
NGRAPH_CHECK(K.is_static());
std::vector<int64_t> reshape_shape_values = { -1ll, static_cast<int64_t>(K.get_length()) };
auto reshape_shape = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 2 }, reshape_shape_values);
fc_input_b = ov::op::util::make_try_fold<ngraph::opset1::Reshape>(fc_input_b, reshape_shape, false);
if (!std::dynamic_pointer_cast<ngraph::opset1::Constant>(fc_input_b.get_node_shared_ptr())) {
auto k_len = K.get_length();
auto reshape_shape_values = matmul->get_transpose_b() ? std::vector<int64_t>{-1, k_len} : std::vector<int64_t>{k_len, -1};
auto reshape_shape = ngraph::op::v0::Constant::create(ngraph::element::i32, ngraph::Shape{ 2 }, reshape_shape_values);
fc_input_b = ov::op::util::make_try_fold<ngraph::op::v1::Reshape>(fc_input_b, reshape_shape, false);
if (!std::dynamic_pointer_cast<ngraph::op::v0::Constant>(fc_input_b.get_node_shared_ptr())) {
new_ops.push_back(reshape_shape);
}
new_ops.push_back(fc_input_b.get_node_shared_ptr());
}
// Weights normalization
if (!matmul->get_transpose_b()) {
fc_input_b = create_transpose(fc_input_b, matmul->get_friendly_name() + "/transpose_b");
}
// Input normalization
if (matmul->get_transpose_a() && rank_a != 1) {
if (matmul->get_transpose_a()) {
fc_input_a = create_transpose(fc_input_a, matmul->get_friendly_name() + "/transpose_a");
}
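
A worked example of the new normalization order above, assuming 3D weights [1, 16, 32] and transpose_b == false (an illustrative case, not from this commit):

// shape_b_aligned = [1, 16, 32], so K = 16
// 1. rank_b != 2: reshape with {K, -1} = {16, -1} turns [1, 16, 32] into [16, 32]
// 2. !transpose_b: create_transpose with order {1, 0} turns [16, 32] into [32, 16],
//    i.e. the transposed 2D weights layout the FullyConnected node consumes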


@@ -0,0 +1,107 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/cpu_opset/common/op/fully_connected.hpp"
#include "move_fc_reshape_to_weights.hpp"
#include <transformations/utils/utils.hpp>
#include <openvino/pass/pattern/op/wrap_type.hpp>
#include <openvino/pass/pattern/op/or.hpp>
#include <openvino/op/constant.hpp>
#include <openvino/op/convert.hpp>
#include <openvino/op/subtract.hpp>
#include <openvino/op/multiply.hpp>
#include <openvino/op/transpose.hpp>
#include <openvino/op/reshape.hpp>
#include "itt.hpp"
ov::intel_cpu::MoveFCReshapeToWeights::MoveFCReshapeToWeights() {
MATCHER_SCOPE(MoveFCReshapeToWeights);
using namespace ov::pass::pattern;
auto weights_m = wrap_type<ov::op::v0::Constant>(consumers_count(1));
auto convert_m = wrap_type<ov::op::v0::Convert>({weights_m});
auto sub_const_m = wrap_type<ov::op::v0::Constant>(consumers_count(1));
auto subtract_m = wrap_type<ov::op::v1::Subtract>({convert_m, sub_const_m});
auto mul_const_m = wrap_type<ov::op::v0::Constant>(consumers_count(1));
auto mul_with_sub_m = wrap_type<ov::op::v1::Multiply>({subtract_m, mul_const_m}, rank_equals(3));
auto mul_no_sub_m = wrap_type<ov::op::v1::Multiply>({convert_m, mul_const_m}, rank_equals(3));
auto mul_m = std::make_shared<ov::pass::pattern::op::Or>(OutputVector{mul_with_sub_m, mul_no_sub_m});
auto one_consumer_rank_2 = [](const ov::Output<ov::Node>& out) {
return consumers_count(1)(out) && rank_equals(2)(out);
};
auto reshape_const_m = wrap_type<ov::op::v0::Constant>(consumers_count(1));
auto reshape_m = wrap_type<ov::op::v1::Reshape>({mul_m, reshape_const_m}, one_consumer_rank_2);
auto transpose_const_m = wrap_type<ov::op::v0::Constant>();
auto transpose_m = wrap_type<ov::op::v1::Transpose>({reshape_m, transpose_const_m});
auto weights_input_m = std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{reshape_m, transpose_m});
auto data_m = any_input();
auto fully_connected_m = wrap_type<ov::intel_cpu::FullyConnectedNode>({data_m, weights_input_m});
ov::matcher_pass_callback callback = [&](ov::pass::pattern::Matcher& m) {
const auto fully_connected = m.get_match_root();
const auto weights_path = fully_connected->get_input_node_shared_ptr(1);
const bool with_transpose = ov::is_type<ov::op::v1::Transpose>(weights_path);
if (with_transpose) {
const auto transpose_const = ov::as_type_ptr<ov::op::v0::Constant>(weights_path->get_input_node_shared_ptr(1));
if (transpose_const->cast_vector<int>() != std::vector<int>{1, 0}) {
return false;
}
}
const auto& fc_input_shape = fully_connected->get_input_shape(1);
const auto reshape = with_transpose ? weights_path->get_input_node_shared_ptr(0) : weights_path;
auto check_decompression_const = [&](const std::shared_ptr<ov::Node>& node) {
if (!ov::is_type<ov::op::v0::Constant>(node))
return false;
ov::Shape expected_shape(3, 1);
const size_t out_channels_idx = with_transpose ? 2 : 1;
expected_shape[out_channels_idx] = fc_input_shape[0];
return node->get_output_shape(0) == expected_shape;
};
const auto mul = reshape->get_input_node_shared_ptr(0);
if (!check_decompression_const(mul->get_input_node_shared_ptr(1)))
return false;
const auto mul_parent = mul->get_input_node_shared_ptr(0);
const bool with_subtract = ov::is_type<ov::op::v1::Subtract>(mul_parent);
if (with_subtract && !check_decompression_const(mul_parent->get_input_node_shared_ptr(1)))
return false;
const auto convert = with_subtract ? mul_parent->get_input_node_shared_ptr(0) : mul_parent;
const auto weights = convert->get_input_node_shared_ptr(0);
ov::Shape expected_weights_shape(3, 1);
expected_weights_shape[1] = fc_input_shape[with_transpose ? 1 : 0];
expected_weights_shape[2] = fc_input_shape[with_transpose ? 0 : 1];
if (weights->get_output_shape(0) != expected_weights_shape)
return false;
auto squeeze_constant = [](const std::shared_ptr<ov::Node>& node) {
const auto constant = ov::as_type_ptr<ov::op::v0::Constant>(node);
auto shape = constant->get_shape();
shape.erase(shape.begin());
const auto new_constant = std::make_shared<ov::op::v0::Constant>(*constant, shape);
ov::replace_node(constant, new_constant);
ov::copy_runtime_info(constant, new_constant);
new_constant->set_friendly_name(constant->get_friendly_name());
};
// We can remove the 3D->2D Reshape if we manually reshape all constants in the weights subgraph
ov::replace_output_update_name(reshape->output(0), reshape->input_value(0));
squeeze_constant(mul->get_input_node_shared_ptr(1));
squeeze_constant(weights);
if (with_subtract)
squeeze_constant(mul_parent->get_input_node_shared_ptr(1));
return true;
};
auto m = std::make_shared<ov::pass::pattern::Matcher>(fully_connected_m, matcher_name);
this->register_matcher(m, callback);
}
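
For concreteness, with an FC weights input of shape [2, 3] and no transpose, the pass squeezes the leading unit dimension from every constant on the weights path (a hypothetical instance mirroring the unit tests later in this diff):

// before: Weights u8 {1, 2, 3} -> Convert -> Subtract(const {1, 2, 1}) -> Multiply(const {1, 2, 1}) -> Reshape {2, 3} -> FC
// after:  Weights u8 {2, 3}    -> Convert -> Subtract(const {2, 1})    -> Multiply(const {2, 1})    -> FC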


@@ -0,0 +1,38 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <openvino/pass/graph_rewrite.hpp>
namespace ov {
namespace intel_cpu {
/**
 * This transformation is applied to FullyConnected nodes with compressed 3D u8 weights. It moves the Reshape
 * on the weights path into the constants so that the Reshape node can be constant folded.
 * Example:
 *          Weights(3D)                                 Weights(2D)
 *             |                                           |
 *          Convert   Subtract_const(3D)                Convert   Subtract_const(2D)
 *             |     /                                     |     /
 *          Subtract(opt)                               Subtract(opt)
 *             |      Multiply_const(3D)     ====>         |      Multiply_const(2D)
 *             |     /                                     |     /
 *          Multiply                                    Multiply
 *             |                                           |
 *          Reshape(2D)                                    |
 *             |                                           |
 *   Data   Transpose(opt)                       Data   Transpose(opt)
 *      \   /                                       \   /
 *   FullyConnected                              FullyConnected
*/
class MoveFCReshapeToWeights: public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("MoveFCReshapeToWeights", "0");
MoveFCReshapeToWeights();
};
} // namespace intel_cpu
} // namespace ov


@@ -14,6 +14,7 @@
#include "common/pass/convert_to_power_static.hpp"
#include "common/pass/convert_to_leaky_relu.hpp"
#include "common/pass/convert_to_swish_cpu.hpp"
#include "common/pass/move_fc_reshape_to_weights.hpp"
#include "transformations/convert_precision.hpp"
#include "transformations/utils/utils.hpp"
#include "common/pass/rnn_sequences_optimization.hpp"
@@ -32,6 +33,8 @@ inline void ConvertToCPUSpecificOpset(std::shared_ptr<ngraph::Function> &nGraphF
ngraph::pass::Manager manager;
manager.set_per_pass_validation(false);
CPU_REGISTER_PASS_COMMON(manager, ConvertMatMulToFC);
CPU_REGISTER_PASS_X64(manager, MoveFCReshapeToWeights);
CPU_REGISTER_PASS_X64(manager, ov::pass::Validate);
CPU_REGISTER_PASS_COMMON(manager, AlignMatMulInputRanks);
CPU_REGISTER_PASS_COMMON(manager, ConvertTileToSeqTiles);
CPU_REGISTER_PASS_X64(manager, ConvertToPowerStatic);


@@ -7,6 +7,7 @@
#include "snippets/op/subgraph.hpp"
#include "snippets/utils.hpp"
#include <transformations/utils/utils.hpp>
#include <utils/general_utils.h>
#include <utils/cpu_utils.hpp>
@@ -396,6 +397,11 @@ bool isSuitableChildForFusingSumActivation(const std::shared_ptr<const Node> &no
bool isSuitableReduceChild(const std::shared_ptr<const Node> &node, const int channelAxis = DEFAULT_AXIS) {
return node->get_output_element_type(0) == ov::element::f32 && isSuitableChildForFusingSimple(node, channelAxis);
}
bool isSuitableMatMulWithConstantPath(const std::shared_ptr<Node>& node) {
return ov::is_type<ov::opset1::MatMul>(node) &&
!ov::is_type<ov::opset1::Constant>(node->get_input_node_shared_ptr(1)) &&
ov::op::util::is_on_constant_path(node->input_value(1));
}
// Continue fusing chain of the passed type if the node has one child
// Otherwise mark node as FusedTerminator (Fused, but fusing chain is interrupted)
void PropagateIfHasOnlyChild(const std::shared_ptr<Node> &node, NodeFusingType nodeType) {
@@ -464,6 +470,15 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr<ov::Model> &m) {
for (auto &node : m->get_ordered_ops()) {
if (is_skipped_op(node))
continue;
// We perform this check separately because here we only mark the weights path;
// the MatMul itself will be checked further
if (isSuitableMatMulWithConstantPath(node)) {
auto markup_func = [](Node* node) {
SetSnippetsNodeType(node->shared_from_this(), snippets::pass::SnippetsNodeType::SkippedByPlugin);
};
std::unordered_set<Node*> visited;
ov::op::util::visit_shape_path(node->get_input_node_ptr(1), visited, markup_func);
}
if (isSuitableConvolutionParent(node)) {
// Initiate fusing chain
SetNodeFusingType(node, NodeFusingType::FusedWithConvolution);


@@ -208,6 +208,32 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
const bool useLpt = !defaultPrecisions.empty();
if (useLpt) {
CPU_REGISTER_PASS_COMMON(manager, ov::pass::MarkDequantizationSubgraph, defaultPrecisions);
} else {
// MarkDequantizationSubgraph is used even in the non-LPT pipeline on X64 platforms
// in order to keep compressed u8 MatMul weights with their decompression operations as-is
CPU_REGISTER_PASS_X64(manager, ov::pass::MarkDequantizationSubgraph, ov::element::TypeVector{ov::element::u8}, true);
CPU_SET_CALLBACK_X64(manager, [](const_node_ptr &node) -> bool {
auto get_single_consumer = [](const_node_ptr &node) -> std::shared_ptr<ov::Node> {
const auto consumers = node->get_output_target_inputs(0);
if (consumers.size() != 1)
return nullptr;
return consumers.begin()->get_node()->shared_from_this();
};
auto consumer = get_single_consumer(node);
if (!consumer)
return true;
if (ov::is_type<ov::opset1::MatMul>(consumer)) {
return false;
} else if (ov::is_type<ov::opset1::Transpose>(consumer)) {
consumer = get_single_consumer(consumer);
if (consumer != nullptr && ov::is_type<ov::opset1::MatMul>(consumer)) {
return false;
}
}
return true;
}, ov::pass::MarkDequantizationSubgraph);
}
auto get_convert_precisions = []() {


@@ -188,6 +188,8 @@ std::vector<std::string> disabledTestPatterns() {
// New plugin API doesn't support changes of pre-processing
R"(.*(Auto|Multi|Hetero).*InferRequestPreprocessTest.*SetPreProcessToInputInfo.*)",
R"(.*(Auto|Multi|Hetero).*InferRequestPreprocessTest.*SetPreProcessToInferRequest.*)",
// Issue: 113727
R"(.*MatMulCompressedWeights.*)",
};
#if defined(OPENVINO_ARCH_X86)


@@ -110,15 +110,11 @@ public:
results.push_back(std::make_shared<ngraph::opset1::Result>(soft_max->output(i)));
function = std::make_shared<ngraph::Function>(results, input_params, "ConcatReshapeConcatPattern");
ov::pass::Serialize serializer("ngraph.xml", "ngraph.bin");
serializer.run_on_model(function);
}
};
TEST_P(ConcatReshapeConcatSubgraphTest, CompareWithRefs) {
run();
ov::pass::Serialize serializer("exec_graph_dyn.xml", "exec_graph_dyn.bin");
serializer.run_on_model(std::const_pointer_cast<ov::Model>(compiledModel.get_runtime_model()));
}
namespace {


@@ -0,0 +1,260 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "test_utils/fusing_test_utils.hpp"
#include "ngraph_functions/builders.hpp"
#include "shared_test_classes/base/ov_subgraph.hpp"
#include "transformations/rt_info/decompression.hpp"
using namespace ngraph;
using namespace InferenceEngine;
using namespace CPUTestUtils;
using namespace ov::test;
namespace SubgraphTestsDefinitions {
/*
 *                  Subtract_const(U8)
 *                     /
 *  Weights(U8)    Convert(F32)
 *      |         /
 *  Convert(F32)  Reshape
 *        \      /     Multiply_const(F32)
 *       Subtract(opt)    /
 *             \       Reshape
 *              \     /
 *             Multiply
 *                |
 *  Data(F32)  Transpose(opt)
 *        \     /
 *        Matmul
 *           |
 *          Bias
 */
using MatmulWeightsDecompressionParams = std::tuple<std::vector<InputShape>, // input shapes
ov::test::ElementType, // weights precision
bool, // transpose on weights
bool, // decompression subtract
bool, // reshape on decompression constants
std::map<std::string, std::string>, // additional config
fusingSpecificParams>;
class MatmulWeightsDecompression : public testing::WithParamInterface<MatmulWeightsDecompressionParams>,
virtual public SubgraphBaseTest,
public CpuTestWithFusing {
public:
static std::string getTestCaseName(testing::TestParamInfo<MatmulWeightsDecompressionParams> obj) {
std::vector<InputShape> inputShapes;
ov::test::ElementType weights_precision;
bool transpose;
bool decompression_sub;
bool reshape_on_decompression;
std::map<std::string, std::string> additional_config;
fusingSpecificParams fusing_params;
std::tie(inputShapes,
weights_precision,
transpose,
decompression_sub,
reshape_on_decompression,
additional_config,
fusing_params) = obj.param;
std::ostringstream result;
for (const auto& shape : inputShapes) {
result << ov::test::utils::partialShape2str({shape.first}) << "_";
}
result << "TS=";
for (const auto& shape : inputShapes) {
result << "(";
if (!shape.second.empty()) {
auto itr = shape.second.begin();
do {
result << ov::test::utils::vec2str(*itr);
} while (++itr != shape.second.end() && result << "_");
}
result << ")_";
}
result << "weights_precision=" << weights_precision << "_";
result << "transpose_weights=" << transpose << "_";
result << "decompression_subtract=" << decompression_sub << "_";
result << "reshape_on_decompression=" << reshape_on_decompression << "_";
result << "config=(";
for (const auto& configEntry : additional_config) {
result << configEntry.first << ", " << configEntry.second << ":";
}
result << ")";
result << CpuTestWithFusing::getTestCaseName(fusing_params);
return result.str();
}
protected:
std::shared_ptr<ov::Model> initSubgraph(std::vector<ov::PartialShape>& inputShapes,
const ov::element::Type data_precision,
const ov::element::Type weights_precision,
const bool transpose_weights,
const bool add_subtract,
const bool reshape_on_decompression) {
auto params = builder::makeDynamicParams(data_precision, {inputShapes[0]});
auto transpose_if_necessary = [&](const ov::Shape& shape) {
if (!transpose_weights)
return shape;
auto transposed_shape = shape;
std::swap(*transposed_shape.rbegin(), *(transposed_shape.rbegin() + 1));
return transposed_shape;
};
auto weights_shape = transpose_if_necessary(inputShapes[1].to_shape());
auto weights = ngraph::builder::makeConstant<uint8_t>(weights_precision, weights_shape, {}, true);
weights->set_friendly_name("Compressed_weights");
auto weights_convert = std::make_shared<ngraph::opset1::Convert>(weights, data_precision);
std::shared_ptr<ov::Node> mul_parent = weights_convert;
auto output_channels = transpose_weights ? *(weights_shape.rbegin() + 1) : *weights_shape.rbegin();
auto scaleshift_target_shape = transpose_if_necessary(ov::Shape{1, output_channels});
auto scaleshift_const_shape = reshape_on_decompression ? ov::Shape{output_channels} : scaleshift_target_shape;
if (add_subtract) {
auto shift_const = ngraph::builder::makeConstant<uint8_t>(weights_precision, scaleshift_const_shape, {}, true);
std::shared_ptr<ov::Node> shift_convert = std::make_shared<ngraph::opset1::Convert>(shift_const, data_precision);
if (reshape_on_decompression) {
auto shift_reshape_const = ov::opset10::Constant::create(ov::element::i32, {scaleshift_target_shape.size()}, scaleshift_target_shape);
auto shift_reshape = std::make_shared<ov::opset10::Reshape>(shift_convert, shift_reshape_const, false);
shift_convert = shift_reshape;
}
mul_parent = std::make_shared<ov::opset10::Subtract>(weights_convert, shift_convert);
}
std::shared_ptr<ov::Node> scale_const = ngraph::builder::makeConstant<float>(data_precision, scaleshift_const_shape, {}, true);
if (reshape_on_decompression) {
auto scale_reshape_const = ov::opset10::Constant::create(ov::element::i32, {scaleshift_target_shape.size()}, scaleshift_target_shape);
auto scale_reshape = std::make_shared<ov::opset10::Reshape>(scale_const, scale_reshape_const, false);
scale_const = scale_reshape;
}
auto multiply = std::make_shared<ov::opset10::Multiply>(mul_parent, scale_const);
std::shared_ptr<ov::Node> matmul_weights = multiply;
if (transpose_weights) {
const size_t rank = matmul_weights->get_output_partial_shape(0).size();
std::vector<int> order(rank);
std::iota(order.begin(), order.end(), 0);
std::swap(*order.rbegin(), *(order.rbegin() + 1));
auto transpose_constant = ov::opset10::Constant::create(ov::element::i32, {rank}, order);
auto transpose = std::make_shared<ov::opset10::Transpose>(matmul_weights, transpose_constant);
matmul_weights = transpose;
}
auto matMul = builder::makeMatMul(params[0], matmul_weights);
return makeNgraphFunction(data_precision, params, matMul, "MatmulWeightsDecompression");
}
void SetUp() override {
targetDevice = ov::test::utils::DEVICE_CPU;
std::vector<InputShape> inputShapes;
ov::test::ElementType weights_precision;
bool transpose_weights;
bool decompression_sub;
bool reshape_on_decompression;
std::map<std::string, std::string> additional_config;
fusingSpecificParams fusing_params;
std::tie(inputShapes,
weights_precision,
transpose_weights,
decompression_sub,
reshape_on_decompression,
additional_config,
fusing_params) = GetParam();
configuration.insert(additional_config.begin(), additional_config.end());
std::tie(postOpMgrPtr, fusedOps) = fusing_params;
init_input_shapes(inputShapes);
ElementType netType = element::f32;
if (additional_config[PluginConfigParams::KEY_ENFORCE_BF16] == PluginConfigParams::YES)
netType = ElementType::bf16;
inType = outType = netType;
function = initSubgraph(inputDynamicShapes, netType, weights_precision, transpose_weights, decompression_sub, reshape_on_decompression);
}
void checkResults() {
const auto& test_param = GetParam();
ov::test::ElementType weights_precision = std::get<1>(test_param);
for (const auto& n : compiledModel.get_runtime_model()->get_ordered_ops()) {
if (n->get_friendly_name() == "Compressed_weights") {
ASSERT_EQ(n->get_output_element_type(0), weights_precision);
}
}
std::map<std::string, std::string> additional_config = std::get<5>(test_param);
const size_t expected_count = additional_config[PluginConfigParams::KEY_ENFORCE_BF16] == PluginConfigParams::YES ? 1 : 0;
CheckNumberOfNodesWithType(compiledModel, "Convert", expected_count);
CheckNumberOfNodesWithType(compiledModel, "Eltwise", expected_count);
CheckNumberOfNodesWithType(compiledModel, "Subgraph", 0);
}
};
TEST_P(MatmulWeightsDecompression, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED()
run();
checkResults();
}
namespace {
std::vector<std::map<std::string, std::string>> filterAdditionalConfig() {
std::vector<std::map<std::string, std::string>> additional_config{CPUTestUtils::cpuEmptyPluginConfig};
if (with_cpu_x86_avx512_core())
additional_config.push_back({{PluginConfigParams::KEY_ENFORCE_BF16, PluginConfigParams::YES}});
return additional_config;
}
const std::vector<ov::test::ElementType> weights_precisions = {ov::element::u8};
const std::vector<std::vector<InputShape>> input_shapes_basic = {
{{{-1, -1, -1}, {{1, 4, 16}, {10, 16, 16}}}, {{}, {{16, 32}}}},
{{{}, {{1, 4, 16}}}, {{}, {{1, 16, 32}}}},
{{{}, {{10, 40, 496}}}, {{}, {{1, 496, 240}}}},
{{{}, {{1, 4, 32}}}, {{}, {{32, 256}}}},
{{{}, {{1, 4, 48}}}, {{}, {{48, 256}}}},
{{{}, {{1, 4, 512}}}, {{}, {{512, 256}}}},
{{{}, {{1, 16, 32}}}, {{}, {{32, 64}}}},
};
const std::vector<fusingSpecificParams> fusingParamsSet {
emptyFusingSpec,
fusingBias,
};
INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_basic,
MatmulWeightsDecompression,
::testing::Combine(::testing::ValuesIn(input_shapes_basic),
::testing::ValuesIn(weights_precisions),
::testing::Values(true),
::testing::Values(true),
::testing::Values(true),
::testing::ValuesIn(filterAdditionalConfig()),
::testing::ValuesIn(fusingParamsSet)),
MatmulWeightsDecompression::getTestCaseName);
const std::vector<std::vector<InputShape>> input_shapes_corner_cases = {
{{{-1, -1, -1}, {{1, 4, 16}}}, {{}, {{1, 16, 32}}}},
{{{-1, -1, -1}, {{1, 4, 16}}}, {{}, {{16, 32}}}},
};
const std::vector<bool> transpose_weights = {true, false};
const std::vector<bool> add_decompression_sub = {true, false};
const std::vector<bool> reshape_on_decompression = {true, false};
INSTANTIATE_TEST_SUITE_P(smoke_MatMulCompressedWeights_corner_cases,
MatmulWeightsDecompression,
::testing::Combine(::testing::ValuesIn(input_shapes_corner_cases),
::testing::ValuesIn(weights_precisions),
::testing::ValuesIn(transpose_weights),
::testing::ValuesIn(add_decompression_sub),
::testing::ValuesIn(reshape_on_decompression),
::testing::Values(CPUTestUtils::cpuEmptyPluginConfig),
::testing::Values(emptyFusingSpec)),
MatmulWeightsDecompression::getTestCaseName);
} // namespace
} // namespace SubgraphTestsDefinitions


@@ -35,7 +35,7 @@ TEST_F(TransformationTestsF, ConvertMatMulToFCTest1) {
}
{
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 });
auto transpose_constant = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 3 }, { 0, 2, 1 });
auto transpose_constant = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 0, 2, 1 });
auto transpose = std::make_shared<ngraph::opset1::Transpose>(input1, transpose_constant);
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 2 }, { 1 });
auto matmul = std::make_shared<FullyConnectedNode>(transpose, input2, ngraph::Rank(3));
@@ -280,29 +280,7 @@ TEST_F(TransformationTestsF, ConvertMatMulToFCTest_second_input_rank_adj_2) {
}
}
TEST_F(TransformationTestsF, ConvertMatMulToFCTest_second_input_rank_adj_3_with_bias) {
{
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 5, 2, 3 });
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 2, 3 }, { 1 });
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, weights, false, true);
auto biases = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 1, 2 }, { 1 });
auto add = std::make_shared<ngraph::opset1::Add>(matmul, biases);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input1 });
manager.register_pass<ConvertMatMulToFC>();
}
{
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 5, 2, 3 });
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 3 }, { 1 });
auto matmul = std::make_shared<FullyConnectedNode>(input1, weights, ngraph::Rank(2));
auto biases = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 1, 2 }, { 1 });
auto add = std::make_shared<ngraph::opset1::Add>(matmul, biases);
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ add }, ngraph::ParameterVector{ input1 });
}
}
TEST_F(TransformationTestsF, ConvertMatMulToFCTest_second_input_rank_adj_3_without_bias) {
TEST_F(TransformationTestsF, ConvertMatMulToFCTest_second_input_rank_adj_3) {
{
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 5, 2, 3 });
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 1, 2, 3 }, { 1 });
@@ -315,7 +293,7 @@ TEST_F(TransformationTestsF, ConvertMatMulToFCTest_second_input_rank_adj_3_witho
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 5, 2, 3 });
auto weights = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{ 2, 3 }, { 1 });
auto matmul = std::make_shared<FullyConnectedNode>(input1, weights, ngraph::Rank(2));
auto matmul = std::make_shared<FullyConnectedNode>(input1, weights, ngraph::Rank(3));
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
}
}
@@ -354,7 +332,7 @@ TEST_F(TransformationTestsF, ConvertMatMulToFCTest_decompress_convert_1) {
}
{
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 });
auto transpose_constant = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 3 }, { 0, 2, 1 });
auto transpose_constant = ngraph::opset1::Constant::create(ngraph::element::i32, ngraph::Shape{ 3 }, { 0, 2, 1 });
auto transpose = std::make_shared<ngraph::opset1::Transpose>(input1, transpose_constant);
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f16, ngraph::Shape{2, 2 }, { 1 });
auto convert = std::make_shared<ngraph::opset1::Convert>(input2, ngraph::element::f32);
@@ -362,4 +340,37 @@ TEST_F(TransformationTestsF, ConvertMatMulToFCTest_decompress_convert_1) {
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
}
}
TEST_F(TransformationTestsF, ConvertMatMulToFCTest_compressed_u8_weights) {
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
auto weights = ngraph::opset1::Constant::create(ngraph::element::u8, ngraph::Shape{1, 2, 2}, {1});
auto convert = std::make_shared<ngraph::opset1::Convert>(weights, ngraph::element::f32);
auto sub_const = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 1, 2}, {1});
auto sub = std::make_shared<ngraph::opset1::Subtract>(convert, sub_const);
auto mul_const = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 1, 2}, {1});
auto mul = std::make_shared<ngraph::opset1::Multiply>(sub, mul_const);
auto matmul = std::make_shared<ngraph::opset1::MatMul>(data, mul);
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{matmul}, ngraph::ParameterVector{data});
manager.register_pass<ConvertMatMulToFC>();
}
{
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{3, 2, 2});
auto weights = ngraph::opset1::Constant::create(ngraph::element::u8, ngraph::Shape{1, 2, 2}, {1});
auto convert = std::make_shared<ngraph::opset1::Convert>(weights, ngraph::element::f32);
auto sub_const = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 1, 2}, {1});
auto sub = std::make_shared<ngraph::opset1::Subtract>(convert, sub_const);
auto mul_const = ngraph::opset1::Constant::create(ngraph::element::f32, ngraph::Shape{1, 1, 2}, {1});
auto mul = std::make_shared<ngraph::opset1::Multiply>(sub, mul_const);
auto reshape_const = ngraph::opset1::Constant::create(ov::element::i32, {2}, {2, -1});
auto reshape = std::make_shared<ngraph::opset1::Reshape>(mul, reshape_const, false);
auto transpose_const = ngraph::opset1::Constant::create(ov::element::i32, {2}, {1, 0});
auto transpose = std::make_shared<ngraph::opset1::Transpose>(reshape, transpose_const);
auto matmul = std::make_shared<FullyConnectedNode>(data, transpose, ngraph::Rank(3));
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ data });
}
}


@@ -0,0 +1,107 @@
// Copyright (C) 2018-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <transformations/cpu_opset/common/pass/move_fc_reshape_to_weights.hpp>
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <openvino/core/model.hpp>
#include <openvino/opsets/opset1.hpp>
#include <transformations/cpu_opset/common/op/fully_connected.hpp>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace testing;
using namespace ov::intel_cpu;
using MoveFCReshapeToWeightsParams = std::tuple<std::pair<ov::PartialShape, ov::Shape>, // data_shape - weights_shape
bool, // add transpose
bool>; // add subtract
class MoveFCReshapeToWeightsTests : public TransformationTestsF, public WithParamInterface<MoveFCReshapeToWeightsParams> {
public:
static std::string getTestCaseName(testing::TestParamInfo<MoveFCReshapeToWeightsParams> obj) {
std::pair<ov::PartialShape, ov::Shape> input_shapes;
bool add_transpose;
bool add_subtract;
std::tie(input_shapes, add_transpose, add_subtract) = obj.param;
std::ostringstream result;
result << "Input_shape=(" << input_shapes.first << ")_Weights_shape=(" << input_shapes.second
<< ")_add_transpose=" << add_transpose << "_add_subtract=" << add_subtract;
return result.str();
}
static std::shared_ptr<ov::Model> initModel(const ov::PartialShape& data_shape,
const ov::Shape& weights_shape,
const bool add_transpose,
const bool add_subtract,
const bool add_reshape) {
auto data = std::make_shared<ov::opset1::Parameter>(ov::element::f32, data_shape);
auto transposed_shape = weights_shape;
if (add_transpose)
std::swap(*(transposed_shape.rbegin() + 1), *transposed_shape.rbegin());
std::shared_ptr<ov::Node> weights_path = ov::opset1::Constant::create(ov::element::u8, transposed_shape, {1});
weights_path = std::make_shared<ov::opset1::Convert>(weights_path, ov::element::f32);
ov::Shape decompression_shape(weights_shape.size(), 1);
const size_t n_idx = add_transpose ? transposed_shape.size() - 1 : transposed_shape.size() - 2;
decompression_shape[n_idx] = transposed_shape[n_idx];
if (add_subtract) {
auto sub_const = ov::opset1::Constant::create(ov::element::f32, decompression_shape, {1});
weights_path = std::make_shared<ov::opset1::Subtract>(weights_path, sub_const);
}
auto mul_const = ov::opset1::Constant::create(ov::element::f32, decompression_shape, {1});
weights_path = std::make_shared<ov::opset1::Multiply>(weights_path, mul_const);
if (add_reshape) {
auto target_shape = transposed_shape;
target_shape.erase(target_shape.begin());
auto reshape_const = ov::opset1::Constant::create(ov::element::i32, {2}, target_shape);
weights_path = std::make_shared<ov::opset1::Reshape>(weights_path, reshape_const, false);
}
if (add_transpose) {
auto transpose_const = ov::opset1::Constant::create(ov::element::i32, {2}, {1, 0});
weights_path = std::make_shared<ov::opset1::Transpose>(weights_path, transpose_const);
}
auto fully_connected = std::make_shared<FullyConnectedNode>(data, weights_path, ov::Rank(3));
return std::make_shared<ov::Model>(ov::NodeVector{fully_connected}, ov::ParameterVector{data});
}
protected:
void SetUp() override {
TransformationTestsF::SetUp();
std::pair<ov::PartialShape, ov::Shape> input_shapes;
bool add_transpose;
bool add_subtract;
std::tie(input_shapes, add_transpose, add_subtract) = this->GetParam();
ov::Shape ref_weights_shape = input_shapes.second;
ref_weights_shape.erase(ref_weights_shape.begin());
model = initModel(input_shapes.first, input_shapes.second, add_transpose, add_subtract, true);
model_ref = initModel(input_shapes.first, ref_weights_shape, add_transpose, add_subtract, false);
manager.register_pass<MoveFCReshapeToWeights>();
}
};
TEST_P(MoveFCReshapeToWeightsTests, CompareFunctions) {}
const std::vector<std::pair<ov::PartialShape, ov::Shape>> input_shapes_wo_transpose = {
{{-1, -1, -1}, {1, 4, 3}}
};
const std::vector<bool> add_transpose = {false, true};
const std::vector<bool> add_subtract = {false, true};
INSTANTIATE_TEST_SUITE_P(TransformationTests_wo_transpose, MoveFCReshapeToWeightsTests,
::testing::Combine(
::testing::ValuesIn(input_shapes_wo_transpose),
::testing::ValuesIn(add_transpose),
::testing::ValuesIn(add_subtract)),
MoveFCReshapeToWeightsTests::getTestCaseName);