[CPU] f16 constant folding on cpu plug-in side for MatMul only (#18079)
This commit is contained in:
parent
0e76496acc
commit
60e40843c0
@ -14,6 +14,7 @@ namespace pass {
|
||||
class TRANSFORMATIONS_API EnableDecompressionConvertConstantFolding;
|
||||
class TRANSFORMATIONS_API DisableDecompressionConvertConstantFolding;
|
||||
class TRANSFORMATIONS_API KeepConstAndDecompression;
|
||||
class TRANSFORMATIONS_API KeepConstAndDecompressionForMatMul;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ov
|
||||
@ -47,3 +48,14 @@ public:
|
||||
OPENVINO_RTTI("KeepConstAndDecompression", "0");
|
||||
KeepConstAndDecompression();
|
||||
};
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief Disables ConstantFolding for Convert operation (just before MatMul operation only) and prevents conversion
|
||||
* of f16 Consts to f32.
|
||||
*/
|
||||
class ov::pass::KeepConstAndDecompressionForMatMul : public MatcherPass {
|
||||
public:
|
||||
OPENVINO_RTTI("KeepConstAndDecompressionForMatMul", "0");
|
||||
KeepConstAndDecompressionForMatMul();
|
||||
};
|
||||
|
@ -7,6 +7,7 @@
|
||||
#include "itt.hpp"
|
||||
#include "openvino/op/constant.hpp"
|
||||
#include "openvino/op/convert.hpp"
|
||||
#include "openvino/op/matmul.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/rt_info/decompression.hpp"
|
||||
#include "transformations/rt_info/disable_constant_folding.hpp"
|
||||
@ -61,11 +62,36 @@ pass::KeepConstAndDecompression::KeepConstAndDecompression() {
|
||||
disable_constant_folding(node);
|
||||
|
||||
if (!is_type<ov::op::v0::Constant>(node->input_value(0).get_node_shared_ptr()))
|
||||
return true;
|
||||
return false;
|
||||
enable_keep_fp16_const(node->input_value(0).get_node_shared_ptr());
|
||||
|
||||
return true;
|
||||
return false;
|
||||
};
|
||||
auto m = std::make_shared<pattern::Matcher>(node_pattern, matcher_name);
|
||||
register_matcher(m, callback);
|
||||
}
|
||||
|
||||
pass::KeepConstAndDecompressionForMatMul::KeepConstAndDecompressionForMatMul() {
|
||||
MATCHER_SCOPE(KeepConstAndDecompressionForMatMul);
|
||||
auto matmul = pass::pattern::wrap_type<ov::op::v0::MatMul>();
|
||||
|
||||
matcher_pass_callback callback = [=](pass::pattern::Matcher& m) {
|
||||
auto node = m.get_match_root();
|
||||
|
||||
// input to matmul is decompression Convert
|
||||
const auto& inp_convert = node->input_value(1).get_node_shared_ptr();
|
||||
if (!is_type<ov::op::v0::Convert>(inp_convert) || !is_decompression(inp_convert))
|
||||
return false;
|
||||
|
||||
disable_constant_folding(inp_convert);
|
||||
|
||||
if (!is_type<ov::op::v0::Constant>(inp_convert->input_value(0).get_node_shared_ptr()))
|
||||
return false;
|
||||
enable_keep_fp16_const(inp_convert->input_value(0).get_node_shared_ptr());
|
||||
|
||||
return false;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<pass::pattern::Matcher>(matmul, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include "nodes/pooling.h"
|
||||
#include "nodes/eltwise.h"
|
||||
#include "nodes/concat.h"
|
||||
#include "nodes/convert.h"
|
||||
#include "nodes/reorder.h"
|
||||
#include "nodes/conv.h"
|
||||
#include "nodes/deconv.h"
|
||||
@ -86,6 +87,10 @@ void GraphOptimizer::ApplyCommonGraphOptimizations(Graph &graph) {
|
||||
MergeConvertAndScaleShift(graph);
|
||||
graph.RemoveDroppedNodes();
|
||||
|
||||
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseFCAndConvertOnWeights");
|
||||
FuseFCAndConvertOnWeights(graph);
|
||||
graph.RemoveDroppedNodes();
|
||||
|
||||
OV_ITT_SCOPE_NEXT(FIRST_INFERENCE, taskChain, "FuseDeconvolutionAndSimpleOperation");
|
||||
FuseDeconvolutionAndSimpleOperation(graph);
|
||||
graph.RemoveDroppedNodes();
|
||||
@ -691,6 +696,20 @@ void GraphOptimizer::MergeConvertAndScaleShift(Graph& graph) {
|
||||
}
|
||||
}
|
||||
|
||||
void GraphOptimizer::FuseFCAndConvertOnWeights(Graph& graph) {
|
||||
// This optimization fuses Convert (fp16 -> bf16/fp32) on weights directly to FC input to allow precision conversion handling based on internal logic
|
||||
// (e.g. fuse conversion with weights reordering)
|
||||
auto& graphNodes = graph.GetNodes();
|
||||
|
||||
for (auto parent : graphNodes) {
|
||||
if (parent->getType() == Type::Convert && parent->isConstant() && parent->getChildEdgeAt(0)->getChild()->getType() == Type::FullyConnected
|
||||
&& parent->getOriginalInputPrecisionAtPort(0) == Precision::FP16
|
||||
&& one_of(parent->getOriginalOutputPrecisionAtPort(0), Precision::FP32, Precision::BF16)) {
|
||||
graph.DropNode(parent);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void GraphOptimizer::FuseConvolutionAndZeroPoints(Graph &graph) {
|
||||
auto& graphNodes = graph.GetNodes();
|
||||
|
||||
|
@ -25,6 +25,7 @@ private:
|
||||
void FuseDeconvolutionAndSimpleOperation(Graph &graph);
|
||||
void FuseMultiplyAndAdd(Graph &graph);
|
||||
void MergeConvertAndScaleShift(Graph& graph);
|
||||
void FuseFCAndConvertOnWeights(Graph& graph);
|
||||
void FuseFullyConnectedAndSimpleOperation(Graph &graph);
|
||||
void FuseMatMulAndSimpleOperation(Graph &graph);
|
||||
void FuseConvolutionAndSimpleOperationThroughMaxPool(Graph &graph);
|
||||
|
@ -14,7 +14,7 @@
|
||||
ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
|
||||
MATCHER_SCOPE(ConvertMatMulToFC);
|
||||
auto activations_m = ngraph::pattern::any_input(ngraph::pattern::has_static_rank());
|
||||
auto weights_m = ngraph::pattern::wrap_type<ngraph::opset1::Constant>();
|
||||
auto weights_m = ngraph::pattern::wrap_type<ngraph::opset1::Constant, ngraph::opset1::Convert>(ngraph::pattern::has_static_rank());
|
||||
auto matmul_m = ngraph::pattern::wrap_type<ngraph::opset1::MatMul>({ activations_m, weights_m }, ngraph::pattern::has_static_rank());
|
||||
|
||||
ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) {
|
||||
@ -29,6 +29,15 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
|
||||
// So in case of adding new operations that takes matmul inputs we need keep update fc_input_a and fc_input_b.
|
||||
auto fc_input_a = pattern_map.at(activations_m);
|
||||
auto fc_input_b = pattern_map.at(weights_m);
|
||||
bool is_convert = false;
|
||||
if (auto convert_node = std::dynamic_pointer_cast<ngraph::opset1::Convert>(fc_input_b.get_node_shared_ptr())) {
|
||||
if (is_decompression(convert_node)) {
|
||||
is_convert = true;
|
||||
fc_input_b = convert_node->get_input_node_shared_ptr(0);
|
||||
} else {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
auto shape_a = fc_input_a.get_partial_shape();
|
||||
auto shape_b = fc_input_b.get_partial_shape();
|
||||
@ -45,8 +54,7 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
|
||||
|
||||
// Check that if second inputs is Constant path and it's shape without ones dimensions has length <= 2
|
||||
// we replace MatMul with FullyConnected operation.
|
||||
if (!std::dynamic_pointer_cast<ngraph::opset1::Constant>(fc_input_b.get_node_shared_ptr()) ||
|
||||
std::count_if(shape_b.begin(), shape_b.end(), [](ngraph::Dimension x) { return x != 1; }) > 2) {
|
||||
if (std::count_if(shape_b.begin(), shape_b.end(), [](ngraph::Dimension x) { return x != 1; }) > 2) {
|
||||
return false;
|
||||
}
|
||||
/*
|
||||
@ -147,9 +155,18 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
|
||||
fc_input_a = create_transpose(fc_input_a, matmul->get_friendly_name() + "/transpose_a");
|
||||
}
|
||||
|
||||
auto output_rank = matmul->get_output_partial_shape(0).rank();
|
||||
// Connect Convert to new input if needed
|
||||
if (is_convert) {
|
||||
auto convert = pattern_map.at(weights_m).get_node_shared_ptr();
|
||||
convert->input(0).replace_source_output(fc_input_b);
|
||||
convert->validate_and_infer_types();
|
||||
fc_input_b = convert;
|
||||
}
|
||||
|
||||
// Create FullyConnected
|
||||
auto fc = std::make_shared<ov::intel_cpu::FullyConnectedNode>(fc_input_a, fc_input_b, output_rank, matmul->get_output_element_type(0));
|
||||
auto output_rank = matmul->get_output_partial_shape(0).rank();
|
||||
auto fc = std::make_shared<ov::intel_cpu::FullyConnectedNode>(fc_input_a, fc_input_b, output_rank,
|
||||
matmul->get_output_element_type(0));
|
||||
fc->set_friendly_name(matmul->get_friendly_name());
|
||||
new_ops.push_back(fc);
|
||||
ngraph::copy_runtime_info(matmul, new_ops);
|
||||
|
@ -253,7 +253,7 @@ bool isSuitableChildForFusingMatMul(const std::shared_ptr<const Node> &node, con
|
||||
ov::PartialShape matmul_shape;
|
||||
for (const auto &parent_out : node->input_values()) {
|
||||
const auto parent = parent_out.get_node_shared_ptr();
|
||||
if (ov::is_type<ov::op::v0::Constant>(parent)) {
|
||||
if (ov::is_type<ov::op::v0::Constant>(parent) || ov::is_type<ov::op::v0::Convert>(parent)) {
|
||||
bias_shape = parent_out.get_shape();
|
||||
num_non_const_inputs++;
|
||||
} else {
|
||||
@ -264,7 +264,8 @@ bool isSuitableChildForFusingMatMul(const std::shared_ptr<const Node> &node, con
|
||||
// first check that weights are constant and both activations and weights have static shape
|
||||
if (grandparents.size() == 2 &&
|
||||
grandparents[1].get_partial_shape().is_static() &&
|
||||
ov::is_type<ov::op::v0::Constant>(grandparents[1].get_node_shared_ptr())) {
|
||||
(ov::is_type<ov::op::v0::Constant>(grandparents[1].get_node_shared_ptr())
|
||||
|| ov::is_type<ov::op::v0::Convert>(grandparents[1].get_node_shared_ptr()))) {
|
||||
auto rank_a = grandparents[0].get_partial_shape().rank().get_length();
|
||||
auto rank_w = grandparents[1].get_partial_shape().rank().get_length();
|
||||
if (rank_a != 1 && rank_w != 1 && rank_a <= 3 && rank_w <= 3)
|
||||
|
@ -202,8 +202,7 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
|
||||
manager.set_per_pass_validation(false);
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::InitNodeInfo);
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::MarkShapeOfSubgraphs);
|
||||
// todo: uncomment KeepConstAndDecompression when xxx-105060 is ready
|
||||
// CPU_REGISTER_PASS_COMMON(manager, ov::pass::KeepConstAndDecompression);
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::KeepConstAndDecompressionForMatMul);
|
||||
|
||||
const bool useLpt = !defaultPrecisions.empty();
|
||||
if (useLpt) {
|
||||
@ -464,6 +463,13 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
|
||||
ov::pass::ConvertQuantizeDequantize);
|
||||
}
|
||||
|
||||
/* In some cases, during the transformation pipeline, some MatMul nodes can be transformed into other nodes. For example, they can become part of
|
||||
AUGRUCell node (see AUGRUCellFusion pass). In such cases, some constant paths will be unfolded, which can lead to crashes in the plugin. To avoid this,
|
||||
we re-mark decompression converts again and finally do CF for those constant paths that are not inputs to MatMul node */
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::EnableDecompressionConvertConstantFolding);
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::KeepConstAndDecompressionForMatMul);
|
||||
CPU_REGISTER_PASS_COMMON(manager, ov::pass::ConstantFolding);
|
||||
|
||||
manager.run_passes(model);
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,217 @@
|
||||
// Copyright (C) 2023 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "test_utils/fusing_test_utils.hpp"
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "shared_test_classes/base/ov_subgraph.hpp"
|
||||
#include "transformations/rt_info/decompression.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
using namespace InferenceEngine;
|
||||
using namespace CPUTestUtils;
|
||||
using namespace ov::test;
|
||||
|
||||
namespace SubgraphTestsDefinitions {
|
||||
|
||||
/* This test checks that the ConvertMatMulToFC transformation should work and the MatMul node is converted to the FC node.
|
||||
* The Convert node should be removed on the CPU plugin side.
|
||||
|
||||
* Graph before:
|
||||
------------ ------------
|
||||
|Input(f32)| |Input(f16)|
|
||||
------------ ------------
|
||||
| |
|
||||
| ---------------------------------
|
||||
| |Convert(decompression f16->f32)|
|
||||
| ---------------------------------
|
||||
| |
|
||||
-----------------------------------------------
|
||||
| MatMul |
|
||||
-----------------------------------------------
|
||||
|
|
||||
--------
|
||||
|Output|
|
||||
--------
|
||||
|
||||
* Exec graph:
|
||||
------------ ------------
|
||||
|Input(f32)| |Input(f16)|
|
||||
------------ ------------
|
||||
| |
|
||||
----------------------------
|
||||
| FullyConnected |
|
||||
----------------------------
|
||||
|
|
||||
--------
|
||||
|Output|
|
||||
--------
|
||||
*/
|
||||
|
||||
using MatMulDecompressConvertParams = std::tuple<
|
||||
std::vector<InputShape>, // input shapes
|
||||
std::pair<bool, bool>, // transposeA, transposeB
|
||||
std::map<std::string, std::string> // additional config
|
||||
>;
|
||||
|
||||
class MatMulDecompressConvertTest : public testing::WithParamInterface<MatMulDecompressConvertParams>,
|
||||
virtual public SubgraphBaseTest, public CPUTestsBase {
|
||||
public:
|
||||
static std::string getTestCaseName(testing::TestParamInfo<MatMulDecompressConvertParams> obj) {
|
||||
std::vector<InputShape> inputShapes;
|
||||
std::pair<bool, bool> transpose;
|
||||
std::map<std::string, std::string> additionalConfig;
|
||||
|
||||
std::tie(inputShapes, transpose, additionalConfig) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
for (const auto& shape : inputShapes) {
|
||||
result << CommonTestUtils::partialShape2str({shape.first}) << "_";
|
||||
}
|
||||
result << "TS=";
|
||||
for (const auto& shape : inputShapes) {
|
||||
result << "(";
|
||||
if (!shape.second.empty()) {
|
||||
auto itr = shape.second.begin();
|
||||
do {
|
||||
result << CommonTestUtils::vec2str(*itr);
|
||||
} while (++itr != shape.second.end() && result << "_");
|
||||
}
|
||||
result << ")_";
|
||||
}
|
||||
result << "transpose_a=" << transpose.first << "_";
|
||||
result << "transpose_b=" << transpose.second << "_";
|
||||
|
||||
result << "config=(";
|
||||
for (const auto& configEntry : additionalConfig) {
|
||||
result << configEntry.first << ", " << configEntry.second << ":";
|
||||
}
|
||||
result << ")";
|
||||
|
||||
return result.str();
|
||||
}
|
||||
|
||||
protected:
|
||||
template<typename T>
|
||||
void transposeShape(T& shape) {
|
||||
IE_ASSERT(shape.size() > 1);
|
||||
std::swap(*(shape.end() - 1), *(shape.end() - 2));
|
||||
}
|
||||
|
||||
void CheckConstFP16() const {
|
||||
auto getExecValue = [](const ov::Node::RTMap& rtInfo, const std::string ¶mName) -> std::string {
|
||||
auto it = rtInfo.find(paramName);
|
||||
IE_ASSERT(rtInfo.end() != it);
|
||||
return it->second.as<std::string>();
|
||||
};
|
||||
|
||||
const auto execFunction = compiledModel.get_runtime_model();
|
||||
ASSERT_NE(nullptr, execFunction);
|
||||
for (const auto &fcNode : execFunction->get_ops()) {
|
||||
if (getExecValue(fcNode->get_rt_info(), ExecGraphInfoSerialization::LAYER_TYPE) == "FullyConnected") {
|
||||
const auto &constNode = fcNode->get_input_node_shared_ptr(1);
|
||||
ASSERT_EQ(getExecValue(constNode->get_rt_info(), ExecGraphInfoSerialization::LAYER_TYPE), "Const");
|
||||
ASSERT_EQ(getExecValue(constNode->get_rt_info(), ExecGraphInfoSerialization::OUTPUT_PRECISIONS), "FP16");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void SetUp() override {
|
||||
targetDevice = CommonTestUtils::DEVICE_CPU;
|
||||
|
||||
std::vector<InputShape> inputShapes;
|
||||
std::pair<bool, bool> transpose;
|
||||
std::map<std::string, std::string> additionalConfig;
|
||||
|
||||
std::tie(inputShapes, transpose, additionalConfig) = this->GetParam();
|
||||
|
||||
init_input_shapes(inputShapes);
|
||||
|
||||
bool transpA = transpose.first;
|
||||
bool transpB = transpose.second;
|
||||
|
||||
if (transpA) {
|
||||
transposeShape(inputDynamicShapes[0]);
|
||||
for (auto& shapes : targetStaticShapes) {
|
||||
transposeShape(shapes[0]);
|
||||
}
|
||||
}
|
||||
if (transpB) {
|
||||
transposeShape(inputDynamicShapes[1]);
|
||||
for (auto& shapes : targetStaticShapes) {
|
||||
transposeShape(shapes[1]);
|
||||
}
|
||||
}
|
||||
|
||||
const auto& inShapeA = inputDynamicShapes[0];
|
||||
const auto& inShapeB = inputDynamicShapes[1];
|
||||
|
||||
configuration.insert(additionalConfig.begin(), additionalConfig.end());
|
||||
|
||||
ElementType netType = element::f32;
|
||||
if (additionalConfig[PluginConfigParams::KEY_ENFORCE_BF16] == PluginConfigParams::YES)
|
||||
inType = outType = netType = ElementType::bf16;
|
||||
else
|
||||
inType = outType = netType;
|
||||
|
||||
std::string cpuNodeType = "FullyConnected";
|
||||
|
||||
auto params = builder::makeDynamicParams(inType, {inShapeA});
|
||||
auto paramOuts = helpers::convert2OutputVector(helpers::castOps2Nodes<opset1::Parameter>(params));
|
||||
|
||||
auto matrixB = ngraph::builder::makeConstant<float16>(element::f16, inShapeB.get_shape(), {}, true);
|
||||
auto convert = std::make_shared<ngraph::opset1::Convert>(matrixB, inType);
|
||||
mark_as_decompression(convert);
|
||||
auto matMul = builder::makeMatMul(paramOuts[0], convert, transpA, transpB);
|
||||
|
||||
function = CPUTestsBase::makeNgraphFunction(netType, params, matMul, cpuNodeType);
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(MatMulDecompressConvertTest, CompareWithRefs) {
|
||||
run();
|
||||
CheckNumberOfNodesWithType(compiledModel, "FullyConnected", 1);
|
||||
CheckNumberOfNodesWithType(compiledModel, "Convert", 0);
|
||||
CheckNumberOfNodesWithType(compiledModel, "Reorder", 0);
|
||||
CheckConstFP16();
|
||||
}
|
||||
|
||||
namespace {
|
||||
|
||||
const std::vector<std::pair<bool, bool>> transposeParams = {
|
||||
{false, false},
|
||||
{false, true},
|
||||
{true, false},
|
||||
{true, true},
|
||||
};
|
||||
|
||||
std::vector<std::map<std::string, std::string>> filterAdditionalConfig() {
|
||||
std::vector<std::map<std::string, std::string>> additionalConfig;
|
||||
additionalConfig.push_back(std::map<std::string, std::string>{/* empty config */});
|
||||
if (with_cpu_x86_avx512_core()) {
|
||||
additionalConfig.push_back({{PluginConfigParams::KEY_ENFORCE_BF16, PluginConfigParams::YES}});
|
||||
}
|
||||
|
||||
return additionalConfig;
|
||||
}
|
||||
|
||||
const auto testParams2D_smoke = ::testing::Combine(
|
||||
::testing::Values(static_shapes_to_test_representation({{2, 3}, {3, 4}})),
|
||||
::testing::ValuesIn(transposeParams),
|
||||
::testing::ValuesIn(filterAdditionalConfig()));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_FC_2D, MatMulDecompressConvertTest, testParams2D_smoke,
|
||||
MatMulDecompressConvertTest::getTestCaseName);
|
||||
|
||||
const auto testParams3D_smoke = ::testing::Combine(
|
||||
::testing::Values(static_shapes_to_test_representation({{1, 2, 3}, {3, 4}}),
|
||||
static_shapes_to_test_representation({{2, 3}, {1, 3, 4}})),
|
||||
::testing::ValuesIn(transposeParams),
|
||||
::testing::ValuesIn(filterAdditionalConfig()));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_FC_3D, MatMulDecompressConvertTest, testParams3D_smoke,
|
||||
MatMulDecompressConvertTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace SubgraphTestsDefinitions
|
@ -19,6 +19,7 @@
|
||||
#include <ov_ops/type_relaxed.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "transformations/rt_info/decompression.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ov::intel_cpu;
|
||||
@ -318,3 +319,47 @@ TEST_F(TransformationTestsF, ConvertMatMulToFCTest_second_input_rank_adj_3_witho
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest_decompress_convert_0) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 });
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f16, ngraph::Shape{ 1, 2, 2 }, { 1 });
|
||||
auto convert = std::make_shared<ngraph::opset1::Convert>(input2, ngraph::element::f32);
|
||||
ov::mark_as_decompression(convert);
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, convert, false, false);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
}
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 });
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f16, ngraph::Shape{2, 2 }, { 1 });
|
||||
auto convert = std::make_shared<ngraph::opset1::Convert>(input2, ngraph::element::f32);
|
||||
auto matmul = std::make_shared<FullyConnectedNode>(input1, convert, ngraph::Rank(3));
|
||||
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
}
|
||||
}
|
||||
|
||||
TEST_F(TransformationTestsF, ConvertMatMulToFCTest_decompress_convert_1) {
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 });
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f16, ngraph::Shape{ 1, 2, 2 }, { 1 });
|
||||
auto convert = std::make_shared<ngraph::opset1::Convert>(input2, ngraph::element::f32);
|
||||
ov::mark_as_decompression(convert);
|
||||
auto matmul = std::make_shared<ngraph::opset1::MatMul>(input1, convert, true, false);
|
||||
|
||||
function = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
manager.register_pass<ConvertMatMulToFC>();
|
||||
}
|
||||
{
|
||||
auto input1 = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32, ngraph::Shape{ 3, 2, 2 });
|
||||
auto transpose_constant = ngraph::opset1::Constant::create(ngraph::element::i64, ngraph::Shape{ 3 }, { 0, 2, 1 });
|
||||
auto transpose = std::make_shared<ngraph::opset1::Transpose>(input1, transpose_constant);
|
||||
auto input2 = ngraph::opset1::Constant::create(ngraph::element::f16, ngraph::Shape{2, 2 }, { 1 });
|
||||
auto convert = std::make_shared<ngraph::opset1::Convert>(input2, ngraph::element::f32);
|
||||
auto matmul = std::make_shared<FullyConnectedNode>(transpose, convert, ngraph::Rank(3));
|
||||
|
||||
function_ref = std::make_shared<ngraph::Function>(ngraph::NodeVector{ matmul }, ngraph::ParameterVector{ input1 });
|
||||
}
|
||||
}
|
2
src/plugins/intel_cpu/thirdparty/onednn
vendored
2
src/plugins/intel_cpu/thirdparty/onednn
vendored
@ -1 +1 @@
|
||||
Subproject commit 33bb2b261d3829162395aaa9bbe8c1c5b139e855
|
||||
Subproject commit 3efb012c7f96b2d5e47270632cdf4b9e4b79b1b8
|
Loading…
Reference in New Issue
Block a user