From a3868698e69ead8899339ed66bf6fe345ad01a02 Mon Sep 17 00:00:00 2001 From: Evgeny Kotov Date: Tue, 13 Jul 2021 16:54:56 +0300 Subject: [PATCH] GNA add SwapInputMatMul transformation unit tests (#6458) * - add SwapInputMatMul transformation unit tests - fix SwapInputMatMul matcher pattern * code review fixes: remove unused piece of code * use clone function instead of creating reference function code duplicate --- .../transformations/swap_input_matmul_gna.cpp | 29 +-- .../transformations/gna_swap_input_matmul.cpp | 171 ++++++++++++++++++ 2 files changed, 181 insertions(+), 19 deletions(-) create mode 100644 inference-engine/tests/unit/gna/ngraph/transformations/gna_swap_input_matmul.cpp diff --git a/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.cpp b/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.cpp index d177b83ba40..fdfcfc254d4 100644 --- a/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.cpp +++ b/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include @@ -22,9 +23,15 @@ NGRAPH_RTTI_DEFINITION(SwapInputMatMul, "SwapInputMatMul", 0); SwapInputMatMul::SwapInputMatMul() { MATCHER_SCOPE(SwapInputMatMul); - auto matmul = ngraph::pattern::wrap_type({ngraph::pattern::any_input( - ngraph::pattern::has_static_shape()), ngraph::pattern::any_input(ngraph::pattern::has_static_shape())}, - ngraph::pattern::has_static_shape()); + auto constant = ngraph::pattern::wrap_type({}, ngraph::pattern::rank_equals(2)); + auto fake_quantize = ngraph::pattern::wrap_type({constant, + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type()}); + auto matmul_input = std::make_shared(ngraph::OutputVector{constant, fake_quantize}); + auto matmul = ngraph::pattern::wrap_type({matmul_input, ngraph::pattern::any_input()}, + ngraph::pattern::has_static_shape()); ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) { auto matmul = std::dynamic_pointer_cast(m.get_match_root()); if (!matmul) { @@ -51,22 +58,6 @@ SwapInputMatMul::SwapInputMatMul() { ngraph::NodeVector new_ops; - // Skip FakeQuantize layers - std::shared_ptr input_a_skip_fq = input_a.get_node_shared_ptr(); - if (std::dynamic_pointer_cast(input_a_skip_fq)) { - input_a_skip_fq = input_a_skip_fq->input_value(0).get_node_shared_ptr(); - } - - std::shared_ptr input_b_skip_fq = input_b.get_node_shared_ptr(); - if (std::dynamic_pointer_cast(input_b_skip_fq)) { - input_b_skip_fq = input_b_skip_fq->input_value(0).get_node_shared_ptr(); - } - - if (!std::dynamic_pointer_cast(input_a_skip_fq) || - std::dynamic_pointer_cast(input_b_skip_fq)) { - return false; - } - if (shape_input_a[0] < 8 || ((shape_input_a[0] % 8 != 0 || shape_input_a[1] % 8 != 0))) { return false; } diff --git a/inference-engine/tests/unit/gna/ngraph/transformations/gna_swap_input_matmul.cpp b/inference-engine/tests/unit/gna/ngraph/transformations/gna_swap_input_matmul.cpp new file mode 100644 index 00000000000..2a80bb9f847 --- /dev/null +++ b/inference-engine/tests/unit/gna/ngraph/transformations/gna_swap_input_matmul.cpp @@ -0,0 +1,171 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "transformations/swap_input_matmul_gna.hpp" + +#include "common_test_utils/ngraph_test_utils.hpp" +#include +#include +#include +#include + +namespace testing { + +TEST(TransformationTests, SwapInputMatMulTestValidConstShape) { + std::shared_ptr func(nullptr), reference_func(nullptr); + const ngraph::Shape data_shape{8, 8}; + + { + auto input_params = std::make_shared(ngraph::element::i64, data_shape); + + auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{1, 8}, {1}); + auto matmul_operation = std::make_shared(constant, input_params); + + auto result = std::make_shared(matmul_operation); + func = std::make_shared(ngraph::ResultVector{result}, + ngraph::ParameterVector{input_params}); + + reference_func = ngraph::clone_function(*func); + + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(func); + ASSERT_NO_THROW(check_rt_info(func)); + } + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(func, reference_func); + ASSERT_TRUE(result.valid); +} + +TEST(TransformationTests, SwapInputMatMulTest) { + std::shared_ptr func(nullptr), reference_func(nullptr); + const ngraph::Shape data_shape{8, 8}; + + { + auto input_params = std::make_shared(ngraph::element::i64, data_shape); + + auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{16, 8}, {1}); + auto matmul_operation = std::make_shared(constant, input_params); + + auto result = std::make_shared(matmul_operation); + func = std::make_shared(ngraph::ResultVector{result}, + ngraph::ParameterVector{input_params}); + + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(func); + ASSERT_NO_THROW(check_rt_info(func)); + } + + { + auto input_params = std::make_shared(ngraph::element::i64, data_shape); + + auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{16, 8}, {1}); + auto matmul_operation = std::make_shared(input_params, constant, 1, 1); + + auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, + std::vector{1, 0}); + auto transpose_operation = std::make_shared(matmul_operation, transpose_order); + + auto result = std::make_shared(transpose_operation); + reference_func = std::make_shared(ngraph::ResultVector{result}, + ngraph::ParameterVector{input_params}); + } + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(func, reference_func); + ASSERT_TRUE(result.valid); +} + +TEST(TransformationTests, SwapInputMatMulTestFakeQuantize) { + std::shared_ptr func(nullptr), reference_func(nullptr); + const ngraph::Shape data_shape{8, 8}; + + { + auto input_params = std::make_shared(ngraph::element::i64, data_shape); + + auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{16, 8}, {1}); + + auto input_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1}); + auto input_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {20}); + auto output_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0}); + auto output_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {10}); + auto fake_quantize_op = std::make_shared(constant, input_low, + input_high, output_low, + output_high, 11); + auto matmul_operation = std::make_shared(fake_quantize_op, input_params); + + auto result = std::make_shared(matmul_operation); + func = std::make_shared(ngraph::ResultVector{result}, + ngraph::ParameterVector{input_params}); + + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(func); + ASSERT_NO_THROW(check_rt_info(func)); + } + + { + auto input_params = std::make_shared(ngraph::element::i64, data_shape); + + auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{16, 8}, {1}); + + auto input_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1}); + auto input_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {20}); + auto output_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0}); + auto output_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {10}); + auto fake_quantize_op = std::make_shared(constant, input_low, + input_high, output_low, + output_high, 11); + auto matmul_operation = std::make_shared(input_params, fake_quantize_op, 1 , 1); + + auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, + std::vector{1, 0}); + auto transpose_operation = std::make_shared(matmul_operation, transpose_order); + + auto result = std::make_shared(transpose_operation); + reference_func = std::make_shared(ngraph::ResultVector{result}, + ngraph::ParameterVector{input_params}); + } + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(func, reference_func); + ASSERT_TRUE(result.valid); +} + +TEST(TransformationTests, SwapInputMatMulTestRank1) { + std::shared_ptr func(nullptr), reference_func(nullptr); + const ngraph::Shape data_shape{8, 8}; + + { + auto input_params = std::make_shared(ngraph::element::i64, data_shape); + + auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{8}, {1}); + auto matmul_operation = std::make_shared(constant, input_params); + + auto result = std::make_shared(matmul_operation); + func = std::make_shared(ngraph::ResultVector{result}, + ngraph::ParameterVector{input_params}); + + reference_func = ngraph::clone_function(*func); + + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.run_passes(func); + ASSERT_NO_THROW(check_rt_info(func)); + } + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(func, reference_func); + ASSERT_TRUE(result.valid); +} + +} // namespace testing