Add MatMulConstTransposesExtraction transformation (#10412)
The transformation inserts a Transpose on MatMul's weights and sets the transpose_b attribute to true. When executed by MO, it helps reduce LoadNetwork time on the CPU plugin, since ConvertMatMulToFC no longer has to insert the Transpose itself. Ticket: 78635
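For illustration (not part of this change), a minimal sketch of how the pass would typically be applied, mirroring the unit tests below; the function name, shapes, and constant value are assumptions:

#include <ngraph/function.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/common_optimizations/matmul_const_transposes_extraction.hpp>

using namespace ngraph;

// Hypothetical helper: builds a MatMul with constant weights and runs the pass on it.
std::shared_ptr<Function> make_and_transform() {
    // MatMul whose second input is a Constant and transpose_b == false.
    auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 3, 4});
    auto weights = opset8::Constant::create(element::f32, Shape{1, 4, 2}, {0.5});
    auto matmul = std::make_shared<opset8::MatMul>(data, weights, false, false);
    auto f = std::make_shared<Function>(matmul, ParameterVector{data});

    // After the pass, the weights feed the MatMul through a Transpose over the last
    // two dimensions (constant-folded for Constant weights) and the rebuilt MatMul
    // has transpose_b == true.
    pass::Manager manager;
    manager.register_pass<pass::MatMulConstTransposesExtraction>();
    manager.run_passes(f);
    return f;
}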
parent 4decf16927
commit 6bb8701651
@@ -0,0 +1,25 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>

/**
 * @ingroup ie_transformation_common_api
 * @brief Resolves transpose_b key from MatMul operation if corresponding input is constant or FakeQuantize by inserting Transpose
 */

namespace ngraph {
namespace pass {

class TRANSFORMATIONS_API MatMulConstTransposesExtraction: public MatcherPass {
public:
    NGRAPH_RTTI_DECLARATION;
    MatMulConstTransposesExtraction();
};

} // namespace pass
} // namespace ngraph
@@ -0,0 +1,51 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/matmul_const_transposes_extraction.hpp"

#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/validation_util.hpp>

NGRAPH_RTTI_DEFINITION(ngraph::pass::MatMulConstTransposesExtraction, "MatMulConstTransposesExtraction", 0);

ngraph::pass::MatMulConstTransposesExtraction::MatMulConstTransposesExtraction() {
    auto data_pattern = pattern::any_input();
    auto weights_pattern = pattern::wrap_type<opset8::Constant,
                                              opset8::FakeQuantize>([](Output<Node> node) -> bool {
        const auto& pshape = node.get_partial_shape();
        const auto& rank = pshape.rank();
        return rank.is_static() && rank.get_length() >= 2 &&
               std::count(pshape.begin(), pshape.end(), 1) >= rank.get_length() - 2;
    });
    auto matmul_pattern = pattern::wrap_type<opset8::MatMul>({data_pattern, weights_pattern});
    matcher_pass_callback callback = [=](pattern::Matcher& m) {
        auto node = m.get_match_root();
        auto matmul = as_type<opset8::MatMul>(node.get());
        if (!matmul || matmul->get_transpose_b())
            return false;

        const auto& pattern_value_map = m.get_pattern_value_map();
        const auto& weights = pattern_value_map.at(weights_pattern);

        std::vector<int> transpose_order(weights.get_partial_shape().size());
        std::iota(transpose_order.begin(), transpose_order.end(), 0);
        std::reverse(transpose_order.end() - 2, transpose_order.end());
        std::shared_ptr<Node> transpose = std::make_shared<opset8::Transpose>(weights,
                op::Constant::create(element::i32, {transpose_order.size()}, transpose_order));
        if (ov::is_type<op::Constant>(weights.get_node())) {
            if (auto constant = get_constant_from_source(transpose))
                transpose = constant;
        }
        auto new_matmul = std::make_shared<opset8::MatMul>(pattern_value_map.at(data_pattern), transpose, matmul->get_transpose_a(), true);
        new_matmul->set_friendly_name(matmul->get_friendly_name());
        copy_runtime_info(node, {transpose, new_matmul});
        replace_node(node, new_matmul);
        return true;
    };

    auto m = std::make_shared<pattern::Matcher>(matmul_pattern, "MatMulConstTransposesExtraction");
    this->register_matcher(m, callback);
}
@@ -58,7 +58,8 @@
 #include <transformations/common_optimizations/nearest_neighbor_upsampling_fusion.hpp>
 #include <transformations/common_optimizations/ric_fusion.hpp>
 #include <transformations/common_optimizations/matmul_multiply_fusion.hpp>
-#include "transformations/common_optimizations/align_eltwise_input_ranks.hpp"
+#include <transformations/common_optimizations/align_eltwise_input_ranks.hpp>
+#include <transformations/common_optimizations/matmul_const_transposes_extraction.hpp>
 
 NGRAPH_RTTI_DEFINITION(ngraph::pass::MOCTransformations, "MOCTransformations", 0);
 
@@ -155,6 +156,7 @@ bool ngraph::pass::MOCTransformations::run_on_model(const std::shared_ptr<ngraph
     common_fusions->add_matcher<ngraph::pass::SubtractFusion>();
     common_fusions->add_matcher<ngraph::pass::TransposeToReshape>();
     common_fusions->add_matcher<ngraph::pass::ReshapeSequenceFusion>(m_use_shapes);
+    common_fusions->add_matcher<ngraph::pass::MatMulConstTransposesExtraction>();
     common_fusions->set_name("ngraph::pass::CommonFusions");
 
     manager.register_pass<ngraph::pass::BinarizeWeights>();
@@ -31,11 +31,12 @@ std::string ov::getPrimitivesPriority(const std::shared_ptr<ngraph::Node>& node)
 }
 
 Any PrimitivesPriority::merge(const ngraph::NodeVector& nodes) const {
-    auto isConvolutionBased = [](const std::shared_ptr<Node>& node) -> bool {
+    auto canBeMerged = [](const std::shared_ptr<Node>& node) -> bool {
         if (std::dynamic_pointer_cast<ngraph::opset1::Convolution>(node) ||
             std::dynamic_pointer_cast<ngraph::opset1::GroupConvolution>(node) ||
             std::dynamic_pointer_cast<ngraph::opset1::GroupConvolutionBackpropData>(node) ||
-            std::dynamic_pointer_cast<ngraph::opset1::ConvolutionBackpropData>(node)) {
+            std::dynamic_pointer_cast<ngraph::opset1::ConvolutionBackpropData>(node) ||
+            std::dynamic_pointer_cast<ngraph::opset1::MatMul>(node)) {
             return true;
         }
         return false;
@@ -44,7 +45,7 @@ Any PrimitivesPriority::merge(const ngraph::NodeVector& nodes) const {
     std::set<std::string> unique_pp;
 
     for (auto& node : nodes) {
-        if (isConvolutionBased(node)) {
+        if (canBeMerged(node)) {
             std::string pp = getPrimitivesPriority(node);
             if (!pp.empty())
                 unique_pp.insert(pp);
@@ -409,7 +409,7 @@ TEST_F(NGraphReaderTests, MatMulBiasFusionNoBroadcast) {
             </output>
         </layer>
         <layer id="3" name="add" precision="FP32" type="FullyConnected">
-            <data alpha="0" beta="0" out-size="1000" originalLayersNames="add,fc"/>
+            <data alpha="0" beta="0" out-size="1000" originalLayersNames="add,fc,weights"/>
             <input>
                 <port id="0">
                     <dim>1</dim>
@@ -641,7 +641,7 @@ TEST_F(NGraphReaderTests, ReadMatMul1DNetwork) {
             </blobs>
         </layer>
         <layer name="fc/Reshape" type="Reshape" precision="FP32" id="2">
-            <data dim="" originalLayersNames="fc" />
+            <data dim="" originalLayersNames="embedded_input__const,fc" />
             <input>
                 <port id="0">
                     <dim>2048</dim>
@@ -658,7 +658,7 @@ TEST_F(NGraphReaderTests, ReadMatMul1DNetwork) {
             </output>
         </layer>
         <layer name="FullyConnected_737" type="FullyConnected" precision="FP32" id="3">
-            <data originalLayersNames="fc" out-size="1000" />
+            <data originalLayersNames="embedded_input__const,fc" out-size="1000" />
             <input>
                 <port id="0">
                     <dim>1</dim>
@@ -687,7 +687,7 @@ TEST_F(NGraphReaderTests, ReadMatMul1DNetwork) {
             </blobs>
         </layer>
         <layer name="fc" type="Reshape" precision="FP32" id="5">
-            <data dim="" originalLayersNames="fc" />
+            <data dim="" originalLayersNames="embedded_input__const,fc" />
             <input>
                 <port id="0">
                     <dim>1</dim>
@@ -0,0 +1,88 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include <gtest/gtest.h>

#include <ngraph/function.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <transformations/common_optimizations/matmul_const_transposes_extraction.hpp>
#include <ngraph/pass/manager.hpp>

#include "common_test_utils/ngraph_test_utils.hpp"

using namespace ngraph;

TEST_F(TransformationTestsF, MatMulConstTransposesExtractionConstantWeights) {
    {
        auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 3, 4});
        auto weights = opset8::Constant::create(element::f32, Shape{1, 3, 2}, {1, 2, 3, 4, 5, 6});
        auto matmul = std::make_shared<opset8::MatMul>(data, weights, true);
        function = std::make_shared<Function>(matmul, ParameterVector{data});

        manager.register_pass<pass::MatMulConstTransposesExtraction>();
    }

    {
        auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 3, 4});
        auto weights = opset8::Constant::create(element::f32, Shape{1, 2, 3}, {1, 3, 5, 2, 4, 6});
        auto matmul = std::make_shared<opset8::MatMul>(data, weights, true, true);
        function_ref = std::make_shared<Function>(matmul, ParameterVector{data});
    }
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
    enable_accuracy_check();
}

TEST_F(TransformationTestsF, MatMulConstTransposesExtractionFQOnWeights) {
    {
        auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 4, 3});
        auto weights = opset8::Constant::create(element::f32, Shape{1, 3, 2}, {1, 2, 3, 4, 5, 6});
        auto low = opset8::Constant::create(element::f32, Shape{1}, {0});
        auto high = opset8::Constant::create(element::f32, Shape{1}, {10});
        auto fq = std::make_shared<opset8::FakeQuantize>(weights, low, high, low, high, 255);
        auto matmul = std::make_shared<opset8::MatMul>(data, fq);
        function = std::make_shared<Function>(matmul, ParameterVector{data});

        manager.register_pass<pass::MatMulConstTransposesExtraction>();
    }

    {
        auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 4, 3});
        auto weights = opset8::Constant::create(element::f32, Shape{1, 3, 2}, {1, 2, 3, 4, 5, 6});
        auto low = opset8::Constant::create(element::f32, Shape{1}, {0});
        auto high = opset8::Constant::create(element::f32, Shape{1}, {10});
        auto fq = std::make_shared<opset8::FakeQuantize>(weights, low, high, low, high, 255);
        auto transpose = std::make_shared<opset8::Transpose>(fq, op::Constant::create(element::i32, Shape{3}, {0, 2, 1}));
        auto matmul = std::make_shared<opset8::MatMul>(data, transpose, false, true);
        function_ref = std::make_shared<Function>(matmul, ParameterVector{data});
    }
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
    enable_accuracy_check();
}

TEST_F(TransformationTestsF, NegativeMatMulConstTransposesExtractionInvalidRank) {
    auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 3, 4});
    auto weights = opset8::Constant::create(element::f32, Shape{3}, {1, 2, 3});
    auto matmul = std::make_shared<opset8::MatMul>(data, weights, true);
    function = std::make_shared<Function>(matmul, ParameterVector{data});
    manager.register_pass<pass::MatMulConstTransposesExtraction>();
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}

TEST_F(TransformationTestsF, NegativeMatMulConstTransposesExtractionTransposeBSet) {
    auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 3, 4});
    auto weights = opset8::Constant::create(element::f32, Shape{1, 2, 3}, {1, 2, 3, 4, 5, 6});
    auto matmul = std::make_shared<opset8::MatMul>(data, weights, true, true);
    function = std::make_shared<Function>(matmul, ParameterVector{data});
    manager.register_pass<pass::MatMulConstTransposesExtraction>();
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}

TEST_F(TransformationTestsF, NegativeMatMulConstTransposesExtractionNonUnitDims) {
    auto data = std::make_shared<opset8::Parameter>(element::f32, Shape{1, 3, 4});
    auto weights = opset8::Constant::create(element::f32, Shape{2, 3, 2}, {1, 2, 3, 4, 5, 6, 2, 3, 4, 5, 6, 7});
    auto matmul = std::make_shared<opset8::MatMul>(data, weights, true);
    function = std::make_shared<Function>(matmul, ParameterVector{data});
    manager.register_pass<pass::MatMulConstTransposesExtraction>();
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
}
@@ -0,0 +1,73 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "subgraph_tests/matmul_const_transposes_extraction.hpp"

using namespace SubgraphTestsDefinitions;

namespace {
std::vector<MatMulConstTransposesExtractionTestShapeParams> shape_params = {
    {{2, 2}, {2, 3}, false},
    {{5}, {5, 1}, false},
    {{5}, {5, 3}, false},
    {{5, 10}, {10, 7}, false},
    {{5, 10}, {1, 10, 7}, false},
    {{5, 10}, {1, 1, 10, 7}, false},
    {{2, 3, 5, 10}, {10, 7}, false},
    {{2, 3, 5, 10}, {1, 10, 7}, false},
    {{2, 3, 5, 10}, {1, 10, 1}, false},
    {{2, 3, 5, 10}, {1, 1, 10, 7}, false},
    {{2, 3, 5, 10}, {1, 1, 10, 1}, false},
};

INSTANTIATE_TEST_SUITE_P(smoke_MatMulConstTransposesExtractionTest, MatMulConstTransposesExtractionTest,
                         ::testing::Combine(
                                 ::testing::ValuesIn(shape_params),
                                 ::testing::Values(true), // can be fused
                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         MatMulConstTransposesExtractionTest::getTestCaseName);

std::vector<MatMulConstTransposesExtractionTestShapeParams> negative_shape_params = {
    {{5}, {5}, false},
    {{5}, {3, 5}, true},
    {{5, 5}, {5, 5}, true},
    {{5, 10}, {7, 10}, true},
    {{5, 10}, {2, 10, 7}, false},
    {{5, 10}, {2, 3, 10, 7}, false},
    {{1, 1, 5, 10}, {10}, false},
    {{1, 1, 5, 10}, {7, 10}, true},
    {{1, 1, 5, 10}, {1, 1, 7, 10}, true},
    {{2, 3, 5, 10}, {7, 10}, true},
    {{2, 3, 5, 10}, {3, 7, 10}, true},
    {{2, 3, 5, 10}, {2, 3, 7, 10}, true},
    {{2, 3, 5, 10}, {3, 10, 7}, false},
    {{2, 3, 5, 10}, {1, 3, 10, 7}, false},
    {{2, 3, 5, 10}, {2, 3, 10, 7}, false},
};

INSTANTIATE_TEST_SUITE_P(smoke_NegativeMatMulConstTransposesExtractionTest, MatMulConstTransposesExtractionTest,
                         ::testing::Combine(
                                 ::testing::ValuesIn(negative_shape_params),
                                 ::testing::Values(false), // cannot be fused
                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         MatMulConstTransposesExtractionTest::getTestCaseName);

std::vector<MatMulConstTransposesExtractionTestShapeParams> shape_params2 = {
    {{2, 2}, {2, 2}, false},
    {{5, 10}, {10, 7}, false},
    {{5, 10}, {1, 10, 7}, false},
    {{5, 10}, {1, 1, 10, 7}, false},
    {{2, 3, 5, 10}, {10, 7}, false},
    {{2, 3, 5, 10}, {1, 10, 7}, false},
    {{2, 3, 5, 10}, {1, 1, 10, 7}, false},
};

INSTANTIATE_TEST_SUITE_P(smoke_QuantizedMatMulConstTransposesExtractionTest, QuantizedMatMulConstTransposesExtractionTest,
                         ::testing::Combine(
                                 ::testing::ValuesIn(shape_params2),
                                 ::testing::Values(true), // can be fused
                                 ::testing::Values(CommonTestUtils::DEVICE_CPU)),
                         QuantizedMatMulConstTransposesExtractionTest::getTestCaseName);

} // namespace
@@ -0,0 +1,19 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "shared_test_classes/subgraph/matmul_const_transposes_extraction.hpp"

namespace SubgraphTestsDefinitions {

TEST_P(MatMulConstTransposesExtractionTest, CompareWithRefs) {
    Run();
}

TEST_P(QuantizedMatMulConstTransposesExtractionTest, CompareWithRefs) {
    Run();
}

} // namespace SubgraphTestsDefinitions
@@ -0,0 +1,47 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include <tuple>
#include <string>
#include "shared_test_classes/base/layer_test_utils.hpp"
#include <ngraph/shape.hpp>

namespace SubgraphTestsDefinitions {

struct MatMulConstTransposesExtractionTestShapeParams {
    ngraph::Shape input_shape;
    ngraph::Shape weights_shape;
    bool trans_b;
};

typedef std::tuple<
        MatMulConstTransposesExtractionTestShapeParams,
        bool,        // whether Mul can be fused to MatMul in this case
        std::string  // Device name
        > MatMulConstTransposesExtractionTestParams;

class MatMulConstTransposesExtractionTest
    : public testing::WithParamInterface<MatMulConstTransposesExtractionTestParams>,
      virtual public LayerTestsUtils::LayerTestsCommon {
public:
    static std::string getTestCaseName(const testing::TestParamInfo<MatMulConstTransposesExtractionTestParams> &obj);

protected:
    void SetUp() override;
};

class QuantizedMatMulConstTransposesExtractionTest
    : public testing::WithParamInterface<MatMulConstTransposesExtractionTestParams>,
      virtual public LayerTestsUtils::LayerTestsCommon {
public:
    static std::string getTestCaseName(const testing::TestParamInfo<MatMulConstTransposesExtractionTestParams> &obj);

protected:
    void SetUp() override;
    void TearDown() override;
};

} // namespace SubgraphTestsDefinitions
@@ -0,0 +1,117 @@
// Copyright (C) 2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "transformations/common_optimizations/matmul_const_transposes_extraction.hpp"
#include "shared_test_classes/subgraph/matmul_const_transposes_extraction.hpp"
#include "ngraph_functions/builders.hpp"
#include <exec_graph_info.hpp>

namespace SubgraphTestsDefinitions {

using namespace ngraph;

std::string MatMulConstTransposesExtractionTest::getTestCaseName(const testing::TestParamInfo<MatMulConstTransposesExtractionTestParams> &obj) {
    MatMulConstTransposesExtractionTestShapeParams shape_params;
    std::string device;
    std::tie(shape_params, std::ignore, device) = obj.param;
    std::ostringstream results;

    results << "input=" << shape_params.input_shape << "_";
    results << "weights=" << shape_params.weights_shape << "_";
    results << "transB=" << std::boolalpha << shape_params.trans_b << "_";
    results << "dev=" << device;
    return results.str();
}

void MatMulConstTransposesExtractionTest::SetUp() {
    MatMulConstTransposesExtractionTestShapeParams shape_params;
    element::Type type = element::f32;
    bool can_be_fused;
    std::tie(shape_params, can_be_fused, targetDevice) = GetParam();

    const auto& input_shape = shape_params.input_shape;
    const auto& weights_shape = shape_params.weights_shape;

    auto param = std::make_shared<opset8::Parameter>(type, input_shape);
    auto weights = opset8::Constant::create(type, weights_shape, {0.5});
    auto matmul = std::make_shared<opset8::MatMul>(param, weights, false, shape_params.trans_b);
    function = std::make_shared<Function>(matmul, ParameterVector{param});

    auto transformed_function = clone_function(*function);
    pass::Manager manager;
    manager.register_pass<pass::MatMulConstTransposesExtraction>();
    manager.run_passes(transformed_function);

    bool functions_equal;
    auto orig_function = clone_function(*function);
    std::tie(functions_equal, std::ignore) = compare_functions(transformed_function, orig_function, true);
    if (can_be_fused) {
        ASSERT_FALSE(functions_equal);
    } else {
        ASSERT_TRUE(functions_equal);
    }
}

std::string QuantizedMatMulConstTransposesExtractionTest::getTestCaseName(
        const testing::TestParamInfo<MatMulConstTransposesExtractionTestParams> &obj) {
    MatMulConstTransposesExtractionTestShapeParams params;
    std::string device;
    std::tie(params, std::ignore, device) = obj.param;
    std::ostringstream results;

    results << "input=" << params.input_shape << "_"
        "weights=" << params.weights_shape << "_"
        "dev=" << device;
    return results.str();
}

void QuantizedMatMulConstTransposesExtractionTest::SetUp() {
    MatMulConstTransposesExtractionTestShapeParams params;
    bool can_be_fused;
    std::tie(params, can_be_fused, targetDevice) = GetParam();

    const auto& input_shape = params.input_shape;
    auto weights_shape = params.weights_shape;

    element::Type type = element::f32;
    auto param = std::make_shared<opset8::Parameter>(type, input_shape);
    std::shared_ptr<Node> input;
    std::shared_ptr<Node> weights = opset8::Constant::create(type, weights_shape, {0.5});
    auto low = opset8::Constant::create(type, {1}, {-2});
    auto high = opset8::Constant::create(type, {1}, {2});
    input = std::make_shared<opset8::FakeQuantize>(param, low, high, low, high, 256);
    weights = std::make_shared<opset8::FakeQuantize>(weights, low, high, low, high, 255);
    auto matmul = std::make_shared<opset8::MatMul>(input, weights, false, false);
    function = std::make_shared<Function>(matmul, ParameterVector{param});

    auto transformed_function = clone_function(*function);
    pass::Manager manager;
    manager.register_pass<pass::MatMulConstTransposesExtraction>();
    manager.run_passes(transformed_function);

    bool functions_equal;
    auto orig_function = clone_function(*function);
    std::tie(functions_equal, std::ignore) = compare_functions(transformed_function, orig_function, true);
    if (can_be_fused) {
        ASSERT_FALSE(functions_equal);
    } else {
        ASSERT_TRUE(functions_equal);
    }
}

void QuantizedMatMulConstTransposesExtractionTest::TearDown() {
    auto runtime_function = executableNetwork.GetExecGraphInfo().getFunction();
    int ops_found = 0;
    for (const auto& node : runtime_function->get_ordered_ops()) {
        const auto& layer_type = node->get_rt_info().at(ExecGraphInfoSerialization::LAYER_TYPE).as<std::string>();
        if (layer_type == "FullyConnected" || layer_type == "MatMul") {
            ops_found++;
            auto inputs = node->input_values();
            ASSERT_EQ(element::u8, inputs[0].get_element_type());
            ASSERT_EQ(element::i8, inputs[1].get_element_type());
        }
    }
    ASSERT_GT(ops_found, 0);
}
} // namespace SubgraphTestsDefinitions