Don't constantfold weights in MatMulConstTransposesExtraction transformation (#17917)

get_constant_from_source for Transpose node calls evaluate method
twice which is unnecessary in this case.

Ticket: CVS-105967
This commit is contained in:
Mateusz Tabaka 2023-06-11 09:49:21 +02:00 committed by GitHub
parent 50c85f01ab
commit 93689cc417
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 5 additions and 9 deletions

View File

@ -1,4 +1,4 @@
// Copyright (C) 2022 Intel Corporation
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@ -33,13 +33,6 @@ ov::pass::MatMulConstTransposesExtraction::MatMulConstTransposesExtraction() {
std::shared_ptr<Node> transpose = std::make_shared<opset8::Transpose>(
weights,
opset8::Constant::create(element::i32, {transpose_order.size()}, transpose_order));
if (ov::is_type<opset8::Constant>(weights.get_node())) {
OPENVINO_SUPPRESS_DEPRECATED_START
if (auto constant = get_constant_from_source(transpose)) {
OPENVINO_SUPPRESS_DEPRECATED_END
transpose = constant;
}
}
auto new_matmul = std::make_shared<opset8::MatMul>(pattern_value_map.at(data_pattern),
transpose,
matmul->get_transpose_a(),

View File

@ -1,4 +1,4 @@
// Copyright (C) 2022 Intel Corporation
// Copyright (C) 2022-2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
@ -7,6 +7,7 @@
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/manager.hpp>
#include <openvino/pass/constant_folding.hpp>
#include <transformations/common_optimizations/matmul_const_transposes_extraction.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
@ -21,6 +22,7 @@ TEST_F(TransformationTestsF, MatMulConstTransposesExtractionConstantWeights) {
function = std::make_shared<Function>(matmul, ParameterVector{data});
manager.register_pass<ov::pass::MatMulConstTransposesExtraction>();
manager.register_pass<ov::pass::ConstantFolding>();
}
{
@ -44,6 +46,7 @@ TEST_F(TransformationTestsF, MatMulConstTransposesExtractionFQOnWeights) {
function = std::make_shared<Function>(matmul, ParameterVector{data});
manager.register_pass<ov::pass::MatMulConstTransposesExtraction>();
manager.register_pass<ov::pass::ConstantFolding>();
}
{