GNA add SwapInputMatMul transformation unit tests (#6458)

* - add SwapInputMatMul transformation unit tests
- fix SwapInputMatMul matcher pattern

* code review fixes: remove unused piece of code

* use clone function instead of creating reference function code duplicate
This commit is contained in:
Evgeny Kotov 2021-07-13 16:54:56 +03:00 committed by GitHub
parent b29ee60202
commit a3868698e6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 181 additions and 19 deletions

View File

@ -8,6 +8,7 @@
#include <vector>
#include <ngraph/pass/manager.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/rt_info.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
@ -22,8 +23,14 @@ NGRAPH_RTTI_DEFINITION(SwapInputMatMul, "SwapInputMatMul", 0);
SwapInputMatMul::SwapInputMatMul() {
MATCHER_SCOPE(SwapInputMatMul);
auto matmul = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>({ngraph::pattern::any_input(
ngraph::pattern::has_static_shape()), ngraph::pattern::any_input(ngraph::pattern::has_static_shape())},
auto constant = ngraph::pattern::wrap_type<ngraph::opset7::Constant>({}, ngraph::pattern::rank_equals(2));
auto fake_quantize = ngraph::pattern::wrap_type<ngraph::opset7::FakeQuantize>({constant,
ngraph::pattern::wrap_type<ngraph::opset7::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset7::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset7::Constant>(),
ngraph::pattern::wrap_type<ngraph::opset7::Constant>()});
auto matmul_input = std::make_shared<ngraph::pattern::op::Or>(ngraph::OutputVector{constant, fake_quantize});
auto matmul = ngraph::pattern::wrap_type<ngraph::opset7::MatMul>({matmul_input, ngraph::pattern::any_input()},
ngraph::pattern::has_static_shape());
ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) {
auto matmul = std::dynamic_pointer_cast<ngraph::opset7::MatMul>(m.get_match_root());
@ -51,22 +58,6 @@ SwapInputMatMul::SwapInputMatMul() {
ngraph::NodeVector new_ops;
// Skip FakeQuantize layers
std::shared_ptr<ngraph::Node> input_a_skip_fq = input_a.get_node_shared_ptr();
if (std::dynamic_pointer_cast<ngraph::opset7::FakeQuantize>(input_a_skip_fq)) {
input_a_skip_fq = input_a_skip_fq->input_value(0).get_node_shared_ptr();
}
std::shared_ptr<ngraph::Node> input_b_skip_fq = input_b.get_node_shared_ptr();
if (std::dynamic_pointer_cast<ngraph::opset7::FakeQuantize>(input_b_skip_fq)) {
input_b_skip_fq = input_b_skip_fq->input_value(0).get_node_shared_ptr();
}
if (!std::dynamic_pointer_cast<ngraph::opset7::Constant>(input_a_skip_fq) ||
std::dynamic_pointer_cast<ngraph::opset7::Constant>(input_b_skip_fq)) {
return false;
}
if (shape_input_a[0] < 8 || ((shape_input_a[0] % 8 != 0 || shape_input_a[1] % 8 != 0))) {
return false;
}

View File

@ -0,0 +1,171 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include "transformations/swap_input_matmul_gna.hpp"
#include "common_test_utils/ngraph_test_utils.hpp"
#include <ngraph/function.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/init_node_info.hpp>
namespace testing {
TEST(TransformationTests, SwapInputMatMulTestValidConstShape) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);
const ngraph::Shape data_shape{8, 8};
{
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, data_shape);
auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1, 8}, {1});
auto matmul_operation = std::make_shared<ngraph::opset7::MatMul>(constant, input_params);
auto result = std::make_shared<ngraph::opset7::Result>(matmul_operation);
func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params});
reference_func = ngraph::clone_function(*func);
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<GNAPluginNS::SwapInputMatMul>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}
TEST(TransformationTests, SwapInputMatMulTest) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);
const ngraph::Shape data_shape{8, 8};
{
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, data_shape);
auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{16, 8}, {1});
auto matmul_operation = std::make_shared<ngraph::opset7::MatMul>(constant, input_params);
auto result = std::make_shared<ngraph::opset7::Result>(matmul_operation);
func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<GNAPluginNS::SwapInputMatMul>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}
{
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, data_shape);
auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{16, 8}, {1});
auto matmul_operation = std::make_shared<ngraph::opset7::MatMul>(input_params, constant, 1, 1);
auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2},
std::vector<size_t>{1, 0});
auto transpose_operation = std::make_shared<ngraph::opset7::Transpose>(matmul_operation, transpose_order);
auto result = std::make_shared<ngraph::opset7::Result>(transpose_operation);
reference_func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params});
}
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}
TEST(TransformationTests, SwapInputMatMulTestFakeQuantize) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);
const ngraph::Shape data_shape{8, 8};
{
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, data_shape);
auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{16, 8}, {1});
auto input_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1});
auto input_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {20});
auto output_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0});
auto output_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {10});
auto fake_quantize_op = std::make_shared<ngraph::opset7::FakeQuantize>(constant, input_low,
input_high, output_low,
output_high, 11);
auto matmul_operation = std::make_shared<ngraph::opset7::MatMul>(fake_quantize_op, input_params);
auto result = std::make_shared<ngraph::opset7::Result>(matmul_operation);
func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params});
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<GNAPluginNS::SwapInputMatMul>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}
{
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, data_shape);
auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{16, 8}, {1});
auto input_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1});
auto input_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {20});
auto output_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0});
auto output_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {10});
auto fake_quantize_op = std::make_shared<ngraph::opset7::FakeQuantize>(constant, input_low,
input_high, output_low,
output_high, 11);
auto matmul_operation = std::make_shared<ngraph::opset7::MatMul>(input_params, fake_quantize_op, 1 , 1);
auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2},
std::vector<size_t>{1, 0});
auto transpose_operation = std::make_shared<ngraph::opset7::Transpose>(matmul_operation, transpose_order);
auto result = std::make_shared<ngraph::opset7::Result>(transpose_operation);
reference_func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params});
}
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}
TEST(TransformationTests, SwapInputMatMulTestRank1) {
std::shared_ptr<ngraph::Function> func(nullptr), reference_func(nullptr);
const ngraph::Shape data_shape{8, 8};
{
auto input_params = std::make_shared<ngraph::opset7::Parameter>(ngraph::element::i64, data_shape);
auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{8}, {1});
auto matmul_operation = std::make_shared<ngraph::opset7::MatMul>(constant, input_params);
auto result = std::make_shared<ngraph::opset7::Result>(matmul_operation);
func = std::make_shared<ngraph::Function>(ngraph::ResultVector{result},
ngraph::ParameterVector{input_params});
reference_func = ngraph::clone_function(*func);
ngraph::pass::Manager m;
m.register_pass<ngraph::pass::InitNodeInfo>();
m.register_pass<GNAPluginNS::SwapInputMatMul>();
m.run_passes(func);
ASSERT_NO_THROW(check_rt_info(func));
}
const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
const FunctionsComparator::Result result = func_comparator(func, reference_func);
ASSERT_TRUE(result.valid);
}
} // namespace testing