Add MatMulConstTransposesExtraction transformation (#10412)

The transformation inserts a Transpose for MatMul's weights and
sets the MatMul's transpose_b attribute to true.
If executed by MO, it helps to reduce LoadNetwork time on the CPU plugin,
since ConvertMatMulToFC doesn't have to insert the Transpose itself.

Ticket: 78635
This commit is contained in:
Mateusz Tabaka 2022-02-21 16:08:28 +01:00 committed by GitHub
parent 4decf16927
commit 6bb8701651
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 431 additions and 8 deletions

View File

@ -0,0 +1,25 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
/**
 * @ingroup ie_transformation_common_api
 * @brief Inserts an explicit Transpose on MatMul's second input when that input is a
 * Constant or FakeQuantize (with rank >= 2 and all batch dimensions equal to 1) and
 * flips the MatMul's transpose_b attribute to true. For constant weights the inserted
 * Transpose is folded away, leaving pre-transposed weights, which saves downstream
 * passes (e.g. ConvertMatMulToFC) from inserting the Transpose at load time.
 */
namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API MatMulConstTransposesExtraction: public MatcherPass {
public:
    NGRAPH_RTTI_DECLARATION;
    // Registers the MatMul-with-constant-weights matcher and its rewrite callback.
    MatMulConstTransposesExtraction();
};

}  // namespace pass
}  // namespace ngraph

View File

@ -0,0 +1,51 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/matmul_const_transposes_extraction.hpp"

#include <algorithm>
#include <numeric>

#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/validation_util.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::MatMulConstTransposesExtraction, "MatMulConstTransposesExtraction", 0);
// Matches MatMul(any, Constant|FakeQuantize) and rewrites it so the weights are
// explicitly transposed and the MatMul's transpose_b attribute becomes true.
ngraph::pass::MatMulConstTransposesExtraction::MatMulConstTransposesExtraction() {
    // First MatMul input: unconstrained.
    auto data_pattern = pattern::any_input();
    // Second input: Constant or FakeQuantize of static rank >= 2 where at most the
    // last two dimensions differ from 1 (i.e. the weights are effectively a 2D matrix,
    // so swapping the last two axes is the only transposition needed).
    auto weights_pattern = pattern::wrap_type<opset8::Constant,
                                              opset8::FakeQuantize>([](Output<Node> node) -> bool {
        const auto& pshape = node.get_partial_shape();
        const auto& rank = pshape.rank();
        return rank.is_static() && rank.get_length() >= 2 &&
               std::count(pshape.begin(), pshape.end(), 1) >= rank.get_length() - 2;
    });
    auto matmul_pattern = pattern::wrap_type<opset8::MatMul>({data_pattern, weights_pattern});
    matcher_pass_callback callback = [=](pattern::Matcher& m) {
        auto node = m.get_match_root();
        auto matmul = as_type<opset8::MatMul>(node.get());
        // Nothing to do if the weights are already consumed as transposed.
        if (!matmul || matmul->get_transpose_b())
            return false;
        const auto& pattern_value_map = m.get_pattern_value_map();
        const auto& weights = pattern_value_map.at(weights_pattern);
        // Build the permutation [0, 1, ..., rank-1] with the last two axes swapped.
        std::vector<int> transpose_order(weights.get_partial_shape().size());
        std::iota(transpose_order.begin(), transpose_order.end(), 0);
        std::reverse(transpose_order.end() - 2, transpose_order.end());
        std::shared_ptr<Node> transpose = std::make_shared<opset8::Transpose>(weights,
            op::Constant::create(element::i32, {transpose_order.size()}, transpose_order));
        if (ov::is_type<op::Constant>(weights.get_node())) {
            // Constant weights: fold the Transpose immediately so the graph carries a
            // single pre-transposed Constant instead of a runtime Transpose node.
            if (auto constant = get_constant_from_source(transpose))
                transpose = constant;
        }
        // transpose_b=true compensates for the inserted Transpose: the product is unchanged.
        auto new_matmul = std::make_shared<opset8::MatMul>(pattern_value_map.at(data_pattern), transpose, matmul->get_transpose_a(), true);
        new_matmul->set_friendly_name(matmul->get_friendly_name());
        copy_runtime_info(node, {transpose, new_matmul});
        replace_node(node, new_matmul);
        return true;
    };
    auto m = std::make_shared<pattern::Matcher>(matmul_pattern, "MatMulConstTransposesExtraction");
    this->register_matcher(m, callback);
}

View File

@ -58,7 +58,8 @@
#include <transformations/common_optimizations/nearest_neighbor_upsampling_fusion.hpp>
#include <transformations/common_optimizations/ric_fusion.hpp>
#include <transformations/common_optimizations/matmul_multiply_fusion.hpp>
#include "transformations/common_optimizations/align_eltwise_input_ranks.hpp"
#include <transformations/common_optimizations/align_eltwise_input_ranks.hpp>
#include <transformations/common_optimizations/matmul_const_transposes_extraction.hpp>
NGRAPH_RTTI_DEFINITION(ngraph::pass::MOCTransformations, "MOCTransformations", 0);
@ -155,6 +156,7 @@ bool ngraph::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph
common_fusions->add_matcher<ngraph::pass::SubtractFusion>();
common_fusions->add_matcher<ngraph::pass::TransposeToReshape>();
common_fusions->add_matcher<ngraph::pass::ReshapeSequenceFusion>(m_use_shapes);
common_fusions->add_matcher<ngraph::pass::MatMulConstTransposesExtraction>();
common_fusions->set_name("ngraph::pass::CommonFusions");
manager.register_pass<ngraph::pass::BinarizeWeights>();

View File

@ -31,11 +31,12 @@ std::string ov::getPrimitivesPriority(const std::shared_ptr<ngraph::Node>& node)
}
Any PrimitivesPriority::merge(const ngraph::NodeVector& nodes) const {
auto isConvolutionBased = [](const std::shared_ptr<Node>& node) -> bool {
auto canBeMerged = [](const std::shared_ptr<Node>& node) -> bool {
if (std::dynamic_pointer_cast<ngraph::opset1::Convolution>(node) ||
std::dynamic_pointer_cast<ngraph::opset1::GroupConvolution>(node) ||
std::dynamic_pointer_cast<ngraph::opset1::GroupConvolutionBackpropData>(node) ||
std::dynamic_pointer_cast<ngraph::opset1::ConvolutionBackpropData>(node)) {
std::dynamic_pointer_cast<ngraph::opset1::ConvolutionBackpropData>(node) ||
std::dynamic_pointer_cast<ngraph::opset1::MatMul>(node)) {
return true;
}
return false;
@ -44,7 +45,7 @@ Any PrimitivesPriority::merge(const ngraph::NodeVector& nodes) const {
std::set<std::string> unique_pp;
for (auto& node : nodes) {
if (isConvolutionBased(node)) {
if (canBeMerged(node)) {
std::string pp = getPrimitivesPriority(node);
if (!pp.empty())
unique_pp.insert(pp);

View File

@ -409,7 +409,7 @@ TEST_F(NGraphReaderTests, MatMulBiasFusionNoBroadcast) {
</output>
</layer>
<layer id="3" name="add" precision="FP32" type="FullyConnected">
<data alpha="0" beta="0" out-size="1000" originalLayersNames="add,fc"/>
<data alpha="0" beta="0" out-size="1000" originalLayersNames="add,fc,weights"/>
<input>
<port id="0">
<dim>1</dim>

View File

@ -641,7 +641,7 @@ TEST_F(NGraphReaderTests, ReadMatMul1DNetwork) {
</blobs>
</layer>
<layer name="fc/Reshape" type="Reshape" precision="FP32" id="2">
<data dim="" originalLayersNames="fc" />
<data dim="" originalLayersNames="embedded_input__const,fc" />
<input>
<port id="0">
<dim>2048</dim>
@ -658,7 +658,7 @@ TEST_F(NGraphReaderTests, ReadMatMul1DNetwork) {
</output>
</layer>
<layer name="FullyConnected_737" type="FullyConnected" precision="FP32" id="3">
<data originalLayersNames="fc" out-size="1000" />
<data originalLayersNames="embedded_input__const,fc" out-size="1000" />
<input>
<port id="0">
<dim>1</dim>
@ -687,7 +687,7 @@ TEST_F(NGraphReaderTests, ReadMatMul1DNetwork) {
</blobs>
</layer>
<layer name="fc" type="Reshape" precision="FP32" id="5">
<data dim="" originalLayersNames="fc" />
<data dim="" originalLayersNames="embedded_input__const,fc" />
<input>
<port id="0">
<dim>1</dim>

View File

@ -0,0 +1,88 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <transformations/common_optimizations/matmul_const_transposes_extraction.hpp>
#include <ngraph/pass/manager.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
using namespace ngraph;
// Constant weights: the pass must fold the inserted Transpose into the constant,
// so the reference model holds pre-transposed values with transpose_b=true.
TEST_F(TransformationTestsF, MatMulConstTransposesExtractionConstantWeights) {
    {
        auto input = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 3, 4});
        auto b = opset8::Constant::create(element::f32, Shape{1, 3, 2}, {1, 2, 3, 4, 5, 6});
        auto mm = std::make_shared<opset8::MatMul>(input, b, true);
        function = std::make_shared<Function>(mm, ParameterVector{input});
        manager.register_pass<pass::MatMulConstTransposesExtraction>();
    }
    {
        auto input = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 3, 4});
        // 2x3 transpose of the original 3x2 weight matrix.
        auto b = opset8::Constant::create(element::f32, Shape{1, 2, 3}, {1, 3, 5, 2, 4, 6});
        auto mm = std::make_shared<opset8::MatMul>(input, b, true, true);
        function_ref = std::make_shared<Function>(mm, ParameterVector{input});
    }
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
    enable_accuracy_check();
}
// FakeQuantize on weights: the Transpose cannot be folded, so the reference model
// keeps an explicit Transpose after the FQ, with transpose_b=true on the MatMul.
TEST_F(TransformationTestsF, MatMulConstTransposesExtractionFQOnWeights) {
    {
        auto input = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 4, 3});
        auto b = opset8::Constant::create(element::f32, Shape{1, 3, 2}, {1, 2, 3, 4, 5, 6});
        auto lo = opset8::Constant::create(element::f32, Shape{1}, {0});
        auto hi = opset8::Constant::create(element::f32, Shape{1}, {10});
        auto quantized = std::make_shared<opset8::FakeQuantize>(b, lo, hi, lo, hi, 255);
        auto mm = std::make_shared<opset8::MatMul>(input, quantized);
        function = std::make_shared<Function>(mm, ParameterVector{input});
        manager.register_pass<pass::MatMulConstTransposesExtraction>();
    }
    {
        auto input = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 4, 3});
        auto b = opset8::Constant::create(element::f32, Shape{1, 3, 2}, {1, 2, 3, 4, 5, 6});
        auto lo = opset8::Constant::create(element::f32, Shape{1}, {0});
        auto hi = opset8::Constant::create(element::f32, Shape{1}, {10});
        auto quantized = std::make_shared<opset8::FakeQuantize>(b, lo, hi, lo, hi, 255);
        // Swap the last two axes of the rank-3 weights.
        auto transposed = std::make_shared<opset8::Transpose>(quantized,
                                                              op::Constant::create(element::i32, Shape{3}, {0, 2, 1}));
        auto mm = std::make_shared<opset8::MatMul>(input, transposed, false, true);
        function_ref = std::make_shared<Function>(mm, ParameterVector{input});
    }
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
    enable_accuracy_check();
}
// 1D weights have no "last two axes" to swap -- the pass must leave the graph untouched.
TEST_F(TransformationTestsF, NegativeMatMulConstTransposesExtractionInvalidRank) {
    auto input = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 3, 4});
    auto vec = opset8::Constant::create(element::f32, Shape{3}, {1, 2, 3});
    auto mm = std::make_shared<opset8::MatMul>(input, vec, true);
    function = std::make_shared<Function>(mm, ParameterVector{input});
    manager.register_pass<pass::MatMulConstTransposesExtraction>();
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
// transpose_b is already true -- the pass has nothing to extract and must not fire.
TEST_F(TransformationTestsF, NegativeMatMulConstTransposesExtractionTransposeBSet) {
    auto input = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 3, 4});
    auto b = opset8::Constant::create(element::f32, Shape{1, 2, 3}, {1, 2, 3, 4, 5, 6});
    auto mm = std::make_shared<opset8::MatMul>(input, b, true, true);
    function = std::make_shared<Function>(mm, ParameterVector{input});
    manager.register_pass<pass::MatMulConstTransposesExtraction>();
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
// Weights with a non-unit batch dimension (2, 3, 2) fail the pattern predicate --
// the pass must not apply.
TEST_F(TransformationTestsF, NegativeMatMulConstTransposesExtractionNonUnitDims) {
    auto input = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 3, 4});
    auto b = opset8::Constant::create(element::f32, Shape{2, 3, 2}, {1, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7});
    auto mm = std::make_shared<opset8::MatMul>(input, b, true);
    function = std::make_shared<Function>(mm, ParameterVector{input});
    manager.register_pass<pass::MatMulConstTransposesExtraction>();
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}

View File

@ -0,0 +1,73 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "subgraph_tests/matmul_const_transposes_extraction.hpp"
using namespace SubgraphTestsDefinitions;
namespace {

// Shapes for which the transformation is expected to fire: weights rank >= 2
// with all batch dimensions equal to 1 (effectively 2D weights), transpose_b false.
std::vector<MatMulConstTransposesExtractionTestShapeParams> shape_params = {
    {{2, 2}, {2, 3}, false},
    {{5}, {5, 1}, false},
    {{5}, {5, 3}, false},
    {{5, 10}, {10, 7}, false},
    {{5, 10}, {1, 10, 7}, false},
    {{5, 10}, {1, 1, 10, 7}, false},
    {{2, 3, 5, 10}, {10, 7}, false},
    {{2, 3, 5, 10}, {1, 10, 7}, false},
    {{2, 3, 5, 10}, {1, 10, 1}, false},
    {{2, 3, 5, 10}, {1, 1, 10, 7}, false},
    {{2, 3, 5, 10}, {1, 1, 10, 1}, false},
};

INSTANTIATE_TEST_SUITE_P(smoke_MatMulConstTransposesExtractionTest, MatMulConstTransposesExtractionTest,
                         ::testing::Combine(
                             ::testing::ValuesIn(shape_params),
                             ::testing::Values(true),  // can be fused
                             ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         MatMulConstTransposesExtractionTest::getTestCaseName);

// Shapes for which the transformation must NOT fire: 1D weights, transpose_b
// already set, or a weights batch dimension different from 1.
std::vector<MatMulConstTransposesExtractionTestShapeParams> negative_shape_params = {
    {{5}, {5}, false},
    {{5}, {3, 5}, true},
    {{5, 5}, {5, 5}, true},
    {{5, 10}, {7, 10}, true},
    {{5, 10}, {2, 10, 7}, false},
    {{5, 10}, {2, 3, 10, 7}, false},
    {{1, 1, 5, 10}, {10}, false},
    {{1, 1, 5, 10}, {7, 10}, true},
    {{1, 1, 5, 10}, {1, 1, 7, 10}, true},
    {{2, 3, 5, 10}, {7, 10}, true},
    {{2, 3, 5, 10}, {3, 7, 10}, true},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true},
    {{2, 3, 5, 10}, {3, 10, 7}, false},
    {{2, 3, 5, 10}, {1, 3, 10, 7}, false},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false},
};

INSTANTIATE_TEST_SUITE_P(smoke_NegativeMatMulConstTransposesExtractionTest, MatMulConstTransposesExtractionTest,
                         ::testing::Combine(
                             ::testing::ValuesIn(negative_shape_params),
                             ::testing::Values(false),  // cannot be fused
                             ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         MatMulConstTransposesExtractionTest::getTestCaseName);

// Positive shapes reused for the quantized (FakeQuantize-wrapped) variant.
std::vector<MatMulConstTransposesExtractionTestShapeParams> shape_params2 = {
    {{2, 2}, {2, 2}, false},
    {{5, 10}, {10, 7}, false},
    {{5, 10}, {1, 10, 7}, false},
    {{5, 10}, {1, 1, 10, 7}, false},
    {{2, 3, 5, 10}, {10, 7}, false},
    {{2, 3, 5, 10}, {1, 10, 7}, false},
    {{2, 3, 5, 10}, {1, 1, 10, 7}, false},
};

INSTANTIATE_TEST_SUITE_P(smoke_QuantizedMatMulConstTransposesExtractionTest, QuantizedMatMulConstTransposesExtractionTest,
                         ::testing::Combine(
                             ::testing::ValuesIn(shape_params2),
                             ::testing::Values(true),  // can be fused
                             ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         QuantizedMatMulConstTransposesExtractionTest::getTestCaseName);

}  // namespace

View File

@ -0,0 +1,19 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "shared_test_classes/subgraph/matmul_const_transposes_extraction.hpp"
namespace SubgraphTestsDefinitions {

// Runs the MatMulConstTransposesExtraction subgraph test on the target device.
TEST_P(MatMulConstTransposesExtractionTest, CompareWithRefs) {
    Run();
}

// Same run for the quantized (FakeQuantize-wrapped) variant of the subgraph.
TEST_P(QuantizedMatMulConstTransposesExtractionTest, CompareWithRefs) {
    Run();
}

}  // namespace SubgraphTestsDefinitions

View File

@ -0,0 +1,47 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <tuple>
#include <string>
#include "shared_test_classes/base/layer_test_utils.hpp"
#include <ngraph/shape.hpp>
namespace SubgraphTestsDefinitions {

// Shape configuration for a single MatMulConstTransposesExtraction test case.
struct MatMulConstTransposesExtractionTestShapeParams {
    ngraph::Shape input_shape;    // shape of the MatMul's first (data) input
    ngraph::Shape weights_shape;  // shape of the constant weights (second input)
    bool trans_b;                 // initial transpose_b attribute of the MatMul
};

typedef std::tuple<
        MatMulConstTransposesExtractionTestShapeParams,
        bool,         // whether the transformation is expected to modify the graph
        std::string   // Device name
        > MatMulConstTransposesExtractionTestParams;

// Checks that MatMulConstTransposesExtraction changes the graph exactly when expected
// (by comparing the transformed function against the original) and runs the model.
class MatMulConstTransposesExtractionTest
    : public testing::WithParamInterface<MatMulConstTransposesExtractionTestParams>,
      virtual public LayerTestsUtils::LayerTestsCommon {
public:
    static std::string getTestCaseName(const testing::TestParamInfo<MatMulConstTransposesExtractionTestParams> &obj);

protected:
    void SetUp() override;
};

// Variant with FakeQuantize on both inputs; additionally inspects the executed
// graph in TearDown to verify quantized precisions on the fused MatMul inputs.
class QuantizedMatMulConstTransposesExtractionTest
    : public testing::WithParamInterface<MatMulConstTransposesExtractionTestParams>,
      virtual public LayerTestsUtils::LayerTestsCommon {
public:
    static std::string getTestCaseName(const testing::TestParamInfo<MatMulConstTransposesExtractionTestParams> &obj);

protected:
    void SetUp() override;
    void TearDown() override;
};

}  // namespace SubgraphTestsDefinitions

View File

@ -0,0 +1,117 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/matmul_const_transposes_extraction.hpp"
#include "shared_test_classes/subgraph/matmul_const_transposes_extraction.hpp"
#include "ngraph_functions/builders.hpp"
#include <exec_graph_info.hpp>
namespace SubgraphTestsDefinitions {
using namespace ngraph;
// Builds a readable test name from the shape parameters and target device.
std::string MatMulConstTransposesExtractionTest::getTestCaseName(const testing::TestParamInfo<MatMulConstTransposesExtractionTestParams> &obj) {
    MatMulConstTransposesExtractionTestShapeParams shape_params;
    std::string device;
    std::tie(shape_params, std::ignore, device) = obj.param;
    std::ostringstream name;
    name << "input=" << shape_params.input_shape << "_"
         << "weights=" << shape_params.weights_shape << "_"
         << "transB=" << std::boolalpha << shape_params.trans_b << "_"
         << "dev=" << device;
    return name.str();
}
// Builds MatMul(Parameter, Constant), applies MatMulConstTransposesExtraction to a
// clone, and asserts the pass changed the graph exactly when it was expected to.
void MatMulConstTransposesExtractionTest::SetUp() {
    MatMulConstTransposesExtractionTestShapeParams shape_params;
    bool can_be_fused;
    std::tie(shape_params, can_be_fused, targetDevice) = GetParam();

    element::Type type = element::f32;
    auto param = std::make_shared<opset8::Parameter>(type, shape_params.input_shape);
    auto weights = opset8::Constant::create(type, shape_params.weights_shape, {0.5});
    auto matmul = std::make_shared<opset8::MatMul>(param, weights, false, shape_params.trans_b);
    function = std::make_shared<Function>(matmul, ParameterVector{param});

    // Run the pass on a clone so `function` stays the untransformed model under test.
    auto transformed_function = clone_function(*function);
    pass::Manager manager;
    manager.register_pass<pass::MatMulConstTransposesExtraction>();
    manager.run_passes(transformed_function);

    auto orig_function = clone_function(*function);
    bool functions_equal;
    std::tie(functions_equal, std::ignore) = compare_functions(transformed_function, orig_function, true);
    if (can_be_fused)
        ASSERT_FALSE(functions_equal);
    else
        ASSERT_TRUE(functions_equal);
}
// Test name for the quantized variant; trans_b is omitted (always false here).
std::string QuantizedMatMulConstTransposesExtractionTest::getTestCaseName(
        const testing::TestParamInfo<MatMulConstTransposesExtractionTestParams> &obj) {
    MatMulConstTransposesExtractionTestShapeParams params;
    std::string device;
    std::tie(params, std::ignore, device) = obj.param;
    std::ostringstream name;
    name << "input=" << params.input_shape << "_"
         << "weights=" << params.weights_shape << "_"
         << "dev=" << device;
    return name.str();
}
void QuantizedMatMulConstTransposesExtractionTest::SetUp() {
MatMulConstTransposesExtractionTestShapeParams params;
bool can_be_fused;
std::tie(params, can_be_fused, targetDevice) = GetParam();
const auto& input_shape = params.input_shape;
auto weights_shape = params.weights_shape;
element::Type type = element::f32;
auto param = std::make_shared<opset8::Parameter>(type, input_shape);
std::shared_ptr<Node> input;
std::shared_ptr<Node> weights = opset8::Constant::create(type, weights_shape, {0.5});
auto low = opset8::Constant::create(type, {1}, {-2});
auto high = opset8::Constant::create(type, {1}, {2});
input = std::make_shared<opset8::FakeQuantize>(param, low, high, low, high, 256);
weights = std::make_shared<opset8::FakeQuantize>(weights, low, high, low, high, 255);
auto matmul = std::make_shared<opset8::MatMul>(input, weights, false, false);
function = std::make_shared<Function>(matmul, ParameterVector{param});
auto transformed_function = clone_function(*function);
pass::Manager manager;
manager.register_pass<pass::MatMulConstTransposesExtraction>();
manager.run_passes(transformed_function);
bool functions_equal;
auto orig_function = clone_function(*function);
std::tie(functions_equal, std::ignore) = compare_functions(transformed_function, orig_function, true);
if (can_be_fused) {
ASSERT_FALSE(functions_equal);
} else {
ASSERT_TRUE(functions_equal);
}
}
// Inspects the executed (runtime) graph: every FullyConnected/MatMul primitive must
// consume quantized inputs (u8 activations, i8 weights), and at least one such
// primitive must exist.
// NOTE(review): the base class TearDown is not invoked here -- confirm
// LayerTestsCommon::TearDown performs no cleanup this override must preserve.
void QuantizedMatMulConstTransposesExtractionTest::TearDown() {
    auto runtime_function = executableNetwork.GetExecGraphInfo().getFunction();
    int ops_found = 0;
    for (const auto& node : runtime_function->get_ordered_ops()) {
        const auto& layer_type = node->get_rt_info().at(ExecGraphInfoSerialization::LAYER_TYPE).as<std::string>();
        if (layer_type == "FullyConnected" || layer_type == "MatMul") {
            ops_found++;
            auto inputs = node->input_values();
            // Quantization must have been applied: u8 data, i8 weights.
            ASSERT_EQ(element::u8, inputs[0].get_element_type());
            ASSERT_EQ(element::i8, inputs[1].get_element_type());
        }
    }
    ASSERT_GT(ops_found, 0);
}
} // namespace SubgraphTestsDefinitions