GNA add InsertTransposeBeforeMatmul unit tests (#6421)

- create dir inference-engine/src/gna_plugin/transformations
- add tests insert_transpose_before_matmul.cpp
This commit is contained in:
Evgeny Kotov 2021-06-29 18:08:46 +03:00 committed by GitHub
parent 1ca53717fb
commit 124f438b4a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 140 additions and 0 deletions

View File

@ -9,6 +9,7 @@
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/rt_info.hpp>
using namespace GNAPluginNS;
@ -59,6 +60,7 @@ InsertTransposeBeforeMatmul::InsertTransposeBeforeMatmul() {
input.replace_source_output(reshapeAfter);
}
ngraph::copy_runtime_info(matmul_node, {transpose, reshapeAfter});
return true;
};

View File

@ -0,0 +1,138 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <tuple>
#include "transformations/insert_transpose_before_matmul.hpp"
#include "common_test_utils/ngraph_test_utils.hpp"
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
namespace testing {
namespace {
std::shared_ptr<ngraph::Function> createFunction(const ngraph::PartialShape& input_values,
const ngraph::Shape& reshape_values,
const ngraph::Shape& matmul_values) {
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_values);
auto new_shape = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{reshape_values.size()}, reshape_values);
auto reshape_operation = std::make_shared<ngraph::opset7::Reshape>(input_params, new_shape, true);
auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{matmul_values.size()}, matmul_values);
auto matmul_operation = std::make_shared<ngraph::opset7::MatMul>(reshape_operation, constant);
auto result = std::make_shared<ngraph::opset7::Result>(matmul_operation);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params});
}
// ---------------------------------------------------------------------------------------------------------------------
class InsertTransposeBeforeMatmulTestInvalidFixture: public CommonTestUtils::TestsCommon,
public ::testing::WithParamInterface<std::tuple<ngraph::PartialShape, ngraph::Shape, ngraph::Shape>> {
public:
void SetUp() override;
public:
std::shared_ptr<ngraph::Function> function, reference_function;
};
void InsertTransposeBeforeMatmulTestInvalidFixture::SetUp() {
ngraph::PartialShape input_shape;
ngraph::Shape reshape_shape, matmul_shape;
std::tie(input_shape, reshape_shape, matmul_shape) = this->GetParam();
function = createFunction(input_shape, reshape_shape, matmul_shape);
reference_function = createFunction(input_shape, reshape_shape, matmul_shape);
}
// ---------------------------------------------------------------------------------------------------------------------
class InsertTransposeBeforeMatmulTestFixture: public CommonTestUtils::TestsCommon,
public ::testing::WithParamInterface<std::tuple<ngraph::PartialShape, ngraph::Shape, ngraph::Shape>> {
public:
void SetUp() override;
std::shared_ptr<ngraph::Function> get_initial_function(const ngraph::PartialShape & input_shape,
const ngraph::Shape & reshape_shape,
const ngraph::Shape & matmul_shape);
std::shared_ptr<ngraph::Function> get_reference(const ngraph::PartialShape & input_shape);
public:
std::shared_ptr<ngraph::Function> function, reference_function;
};
void InsertTransposeBeforeMatmulTestFixture::SetUp() {
ngraph::PartialShape input_shape;
ngraph::Shape reshape_shape, matmul_shape;
std::tie(input_shape, reshape_shape, matmul_shape) = this->GetParam();
function = get_initial_function(input_shape, reshape_shape, matmul_shape);
reference_function = get_reference(input_shape);
}
std::shared_ptr<ngraph::Function> InsertTransposeBeforeMatmulTestFixture::get_initial_function(const ngraph::PartialShape & input_shape,
const ngraph::Shape & reshape_shape,
const ngraph::Shape & matmul_shape) {
return createFunction(input_shape, reshape_shape, matmul_shape);
}
std::shared_ptr<ngraph::Function> InsertTransposeBeforeMatmulTestFixture::get_reference(const ngraph::PartialShape & input_shape) {
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, input_shape);
auto new_shape = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {8, 2});
auto reshape_operation = std::make_shared<ngraph::opset7::Reshape>(input_params, new_shape, true);
auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2},
std::vector<size_t>{1, 0});
auto transpose_operation = std::make_shared<ngraph::opset7::Transpose>(reshape_operation, transpose_order);
auto new_shape_after_transpose = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {8, 2});
auto reshape_after_transpose = std::make_shared<ngraph::opset7::Reshape>(transpose_operation,
new_shape_after_transpose,
false);
auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, {2, 1});
auto matmul_operation = std::make_shared<ngraph::opset7::MatMul>(reshape_after_transpose, constant);
auto result = std::make_shared<ngraph::opset7::Result>(matmul_operation);
return std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params});
}
// ---------------------------------------------------------------------------------------------------------------------
void execute_test(std::shared_ptr<ngraph::Function> function, std::shared_ptr<ngraph::Function> reference_function) {
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<GNAPluginNS::InsertTransposeBeforeMatmul>();
manager.run_passes(function);
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(function, reference_function);
ASSERT_TRUE(result.valid);
}
TEST_P(InsertTransposeBeforeMatmulTestFixture, CompareFunctions) {
execute_test(function, reference_function);
}
INSTANTIATE_TEST_SUITE_P(InsertTransposeBeforeMatmulTestSuite, InsertTransposeBeforeMatmulTestFixture,
::testing::Values(std::make_tuple(ngraph::PartialShape{2, 8}, ngraph::Shape{8, 2}, ngraph::Shape{2, 1}),
std::make_tuple(ngraph::PartialShape{1, 16}, ngraph::Shape{8, 2}, ngraph::Shape{2, 1})));
TEST_P(InsertTransposeBeforeMatmulTestInvalidFixture, CompareFunctions) {
execute_test(function, reference_function);
}
INSTANTIATE_TEST_SUITE_P(InsertTransposeBeforeMatmulTestInvalidSuite, InsertTransposeBeforeMatmulTestInvalidFixture,
::testing::Values(std::make_tuple(ngraph::PartialShape{2, 9}, ngraph::Shape{9, 2}, ngraph::Shape{2, 1}),
std::make_tuple(ngraph::PartialShape{9, 2}, ngraph::Shape{9, 2}, ngraph::Shape{2, 1})));
} // namespace
} // namespace testing