diff --git a/inference-engine/src/gna_plugin/gna_plugin.cpp b/inference-engine/src/gna_plugin/gna_plugin.cpp index 7ffab9442c3..9208fd55f53 100644 --- a/inference-engine/src/gna_plugin/gna_plugin.cpp +++ b/inference-engine/src/gna_plugin/gna_plugin.cpp @@ -685,6 +685,8 @@ void GNAPlugin::LoadNetwork(CNNNetwork & _network) { manager.register_pass(); manager.register_pass(); manager.register_pass(); + manager.register_pass(); + manager.register_pass(); manager.register_pass(); manager.register_pass(); manager.register_pass(); diff --git a/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.cpp b/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.cpp index 9a725c33cf7..888fa1ee4a8 100644 --- a/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.cpp +++ b/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.cpp @@ -6,6 +6,7 @@ #include #include +#include #include #include #include @@ -17,84 +18,160 @@ using namespace GNAPluginNS; NGRAPH_RTTI_DEFINITION(SwapInputMatMul, "SwapInputMatMul", 0); +NGRAPH_RTTI_DEFINITION(SwapInputMatMulWithBias, "SwapInputMatMulWithBias", 0); +NGRAPH_RTTI_DEFINITION(SwapInputMatMulWithFq, "SwapInputMatMulWithFq", 0); + +static void SwapAndTransposeInputs(std::shared_ptr matmul_node, + std::shared_ptr add, + std::shared_ptr bias, + std::shared_ptr fq) { + auto create_transpose = + [](ngraph::Output node, const std::string& transpose_name) -> std::shared_ptr { + ngraph::Shape output_shape = node.get_node_shared_ptr()->get_shape(); + + std::vector transpose_order(output_shape.size()); + std::iota(transpose_order.begin(), transpose_order.end(), 0); + std::swap(*(transpose_order.end() - 1), *(transpose_order.end() - 2)); + + auto transpose = std::make_shared( + node, ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape {transpose_order.size()}, transpose_order)); + transpose->set_friendly_name(transpose_name); + return transpose; + }; + + ngraph::NodeVector new_ops; + + gnalog() << "Swap and transpose inputs for " << matmul_node->get_friendly_name() << "\n"; + std::shared_ptr new_matmul = std::make_shared( + matmul_node->input(1).get_source_output(), matmul_node->input(0).get_source_output(), + !matmul_node->get_transpose_b(), !matmul_node->get_transpose_a()); + new_matmul->set_friendly_name(matmul_node->get_friendly_name() + "/swap_inputs"); + new_ops.push_back(new_matmul); + + std::shared_ptr old_root_node = matmul_node; + if (bias != nullptr) { + // output of MatMul will be transposed comparing with original one, so the bias should be transposed too + if (bias->get_output_shape(0).size() > 1) { + bias = create_transpose(bias, bias->get_friendly_name() + "/transpose"); + new_ops.push_back(bias); + } + + new_matmul = std::make_shared(new_matmul, bias); + old_root_node = add; + new_ops.push_back(new_matmul); + } + + if (fq != nullptr) { + new_matmul = fq->clone_with_new_inputs({new_matmul, fq->input_value(1), fq->input_value(2), + fq->input_value(3), fq->input_value(4)}); + old_root_node = fq; + new_ops.push_back(new_matmul); + } + + auto output = create_transpose(new_matmul, matmul_node->get_friendly_name()); + new_ops.push_back(output); + + ngraph::copy_runtime_info(matmul_node, new_ops); + ngraph::replace_node(old_root_node, output); +} SwapInputMatMul::SwapInputMatMul() { - auto matmul = ngraph::pattern::wrap_type({ngraph::pattern::any_input( - ngraph::pattern::has_static_shape()), ngraph::pattern::any_input(ngraph::pattern::has_static_shape())}, - ngraph::pattern::has_static_shape()); - ngraph::matcher_pass_callback callback = [this](ngraph::pattern::Matcher& m) { - auto matmul = std::dynamic_pointer_cast(m.get_match_root()); - if (!matmul) { + auto constant = ngraph::pattern::wrap_type({}, [](const ngraph::Output& node) { + auto shape = node.get_node_shared_ptr()->get_output_shape(0); + if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) { return false; } - - auto input_a = matmul->input(0).get_source_output(); - auto input_b = matmul->input(1).get_source_output(); - - ngraph::Shape shape_input_a = input_a.get_shape(); - - auto create_transpose = [this](ngraph::Output node, const std::string& transpose_name) -> std::shared_ptr { - ngraph::Shape output_shape = node.get_node_shared_ptr()->get_shape(); - - std::vector transpose_order(output_shape.size()); - std::iota(transpose_order.begin(), transpose_order.end(), 0); - std::swap(*(transpose_order.end() - 1), *(transpose_order.end() - 2)); - - auto transpose = register_new_node( - node, ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape {transpose_order.size()}, transpose_order)); - transpose->set_friendly_name(transpose_name); - return transpose; - }; - - ngraph::NodeVector new_ops; - - // Skip FakeQuantize layers - std::shared_ptr input_a_skip_fq = input_a.get_node_shared_ptr(); - if (std::dynamic_pointer_cast(input_a_skip_fq)) { - input_a_skip_fq = input_a_skip_fq->input_value(0).get_node_shared_ptr(); - } - - std::shared_ptr input_b_skip_fq = input_b.get_node_shared_ptr(); - if (std::dynamic_pointer_cast(input_b_skip_fq)) { - input_b_skip_fq = input_b_skip_fq->input_value(0).get_node_shared_ptr(); - } - - if (!std::dynamic_pointer_cast(input_a_skip_fq) || - std::dynamic_pointer_cast(input_b_skip_fq)) { - return false; - } - - if (shape_input_a[0] < 8 || ((shape_input_a[0] % 8 != 0 || shape_input_a[1] % 8 != 0))) { - return false; - } - - gnalog() << "Swap and transpose inputs for " << matmul->get_friendly_name() << "\n"; - auto new_matmul = std::make_shared(input_b, input_a, !matmul->get_transpose_b(), !matmul->get_transpose_a()); - new_matmul->set_friendly_name(matmul->get_friendly_name() + "/swap_inputs"); - new_ops.push_back(new_matmul); - - if (!matmul->get_output_target_inputs(0).empty()) { - auto matmul_out = matmul->get_output_target_inputs(0).begin()->get_node()->shared_from_this(); - if (std::dynamic_pointer_cast(matmul_out) != nullptr) { - ngraph::copy_runtime_info(matmul, new_ops); - ngraph::replace_node(matmul, new_matmul); - auto consumers = matmul_out->output(0).get_target_inputs(); - auto traspose_output = create_transpose(matmul_out, matmul->get_friendly_name()); - for (auto input : consumers) { - input.replace_source_output(traspose_output); - } - return true; - } - } - - auto traspose_output = create_transpose(new_matmul, matmul->get_friendly_name()); - new_ops.push_back(traspose_output); - - ngraph::copy_runtime_info(matmul, new_ops); - ngraph::replace_node(matmul, traspose_output); + return true; + }); + auto fake_quantize = ngraph::pattern::wrap_type({constant, + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type()}); + auto matmul_input = std::make_shared(ngraph::OutputVector{constant, fake_quantize}); + auto matmul = ngraph::pattern::wrap_type({matmul_input, ngraph::pattern::any_input()}, + ngraph::pattern::has_static_shape()); + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + auto matmul_node = std::dynamic_pointer_cast(pattern_map.at(matmul).get_node_shared_ptr()); + IE_ASSERT(matmul_node != nullptr); + SwapAndTransposeInputs(matmul_node, nullptr, nullptr, nullptr); return true; }; auto m = std::make_shared(matmul, "SwapInputMatMul"); this->register_matcher(m, callback); +} + +SwapInputMatMulWithBias::SwapInputMatMulWithBias() { + auto constant = ngraph::pattern::wrap_type({}, [](const ngraph::Output& node) { + auto shape = node.get_node_shared_ptr()->get_output_shape(0); + if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) { + return false; + } + return true; + }); + auto fake_quantize = ngraph::pattern::wrap_type({constant, + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type()}); + auto matmul_input = std::make_shared(ngraph::OutputVector{constant, fake_quantize}); + auto matmul = ngraph::pattern::wrap_type({matmul_input, ngraph::pattern::any_input()}, + ngraph::pattern::has_static_shape()); + auto bias = ngraph::pattern::wrap_type(); + auto add = ngraph::pattern::wrap_type({matmul, bias}); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + auto matmul_node = std::dynamic_pointer_cast(pattern_map.at(matmul).get_node_shared_ptr()); + IE_ASSERT(matmul_node != nullptr); + SwapAndTransposeInputs(matmul_node, pattern_map.at(add).get_node_shared_ptr(), + pattern_map.at(bias).get_node_shared_ptr(), nullptr); + return true; + }; + + auto m = std::make_shared(add, "SwapInputMatMulWithBias"); + this->register_matcher(m, callback); +} + +SwapInputMatMulWithFq::SwapInputMatMulWithFq() { + auto constant = ngraph::pattern::wrap_type({}, [](const ngraph::Output& node) { + auto shape = node.get_node_shared_ptr()->get_output_shape(0); + if (shape.size() != 2 || shape[0] < 8 || ((shape[0] % 8 != 0 || shape[1] % 8 != 0))) { + return false; + } + return true; + }); + auto fake_quantize = ngraph::pattern::wrap_type({constant, + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type()}); + auto matmul_input = std::make_shared(ngraph::OutputVector{constant, fake_quantize}); + auto matmul = ngraph::pattern::wrap_type({matmul_input, ngraph::pattern::any_input()}, + ngraph::pattern::has_static_shape()); + auto bias = ngraph::pattern::wrap_type(); + auto add = ngraph::pattern::wrap_type({matmul, bias}); + auto matmul_out = std::make_shared(ngraph::OutputVector{add, matmul}); + auto out_fq = ngraph::pattern::wrap_type({matmul_out, + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type(), + ngraph::pattern::wrap_type()}); + + ngraph::matcher_pass_callback callback = [=](ngraph::pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + auto matmul_node = std::dynamic_pointer_cast(pattern_map.at(matmul).get_node_shared_ptr()); + IE_ASSERT(matmul_node != nullptr); + auto add_it = pattern_map.find(add); + auto add_node = (add_it == std::end(pattern_map) ? nullptr : add_it->second.get_node_shared_ptr()); + auto bias_it = pattern_map.find(bias); + auto bias_node = (bias_it == std::end(pattern_map) ? nullptr : bias_it->second.get_node_shared_ptr()); + SwapAndTransposeInputs(matmul_node, add_node, bias_node, pattern_map.at(out_fq).get_node_shared_ptr()); + return true; + }; + + auto m = std::make_shared(out_fq, "SwapInputMatMulWithFq"); + this->register_matcher(m, callback); } \ No newline at end of file diff --git a/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.hpp b/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.hpp index 66816868915..c9604f8b7c2 100644 --- a/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.hpp +++ b/inference-engine/src/gna_plugin/transformations/swap_input_matmul_gna.hpp @@ -1,4 +1,4 @@ -// Copyright (C) 2020 Intel Corporation +// Copyright (C) 2020-2021 Intel Corporation // SPDX-License-Identifier: Apache-2.0 // @@ -16,4 +16,16 @@ public: NGRAPH_RTTI_DECLARATION; SwapInputMatMul(); }; + +class SwapInputMatMulWithBias: public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + SwapInputMatMulWithBias(); +}; + +class SwapInputMatMulWithFq: public ngraph::pass::MatcherPass { +public: + NGRAPH_RTTI_DECLARATION; + SwapInputMatMulWithFq(); +}; } // namespace GNAPluginNS \ No newline at end of file diff --git a/inference-engine/tests/unit/gna/ngraph/transformations/gna_swap_input_matmul.cpp b/inference-engine/tests/unit/gna/ngraph/transformations/gna_swap_input_matmul.cpp new file mode 100644 index 00000000000..b87f39ce265 --- /dev/null +++ b/inference-engine/tests/unit/gna/ngraph/transformations/gna_swap_input_matmul.cpp @@ -0,0 +1,172 @@ +// Copyright (C) 2021 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "transformations/swap_input_matmul_gna.hpp" + +#include "common_test_utils/ngraph_test_utils.hpp" +#include +#include +#include +#include + +namespace testing { + +static std::shared_ptr CreateMatMulFunction(const ngraph::Shape& input1_shape, + const ngraph::Shape& input2_shape, + const ngraph::Shape& bias_shape, + bool withBias, + bool withWeightsFq, + bool withOutFq, + bool swappedInputs) { + auto input_params = std::make_shared(ngraph::element::i64, input2_shape); + + auto constant = ngraph::opset7::Constant::create(ngraph::element::i64, input1_shape, {1}); + std::shared_ptr const_input = constant; + if (withWeightsFq) { + auto input_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1}); + auto input_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {20}); + auto output_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0}); + auto output_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {10}); + const_input = std::make_shared(const_input, input_low, input_high, + output_low, output_high, 11); + } + auto matmul = swappedInputs ? std::make_shared(input_params, const_input, true, true) : + std::make_shared(const_input, input_params); + + std::shared_ptr final_node = matmul; + if (withBias) { + auto bias = ngraph::opset7::Constant::create(ngraph::element::i64, bias_shape, {1}); + std::shared_ptr bias_node = bias; + if (swappedInputs && bias_shape.size() > 1) { + auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, + std::vector{1, 0}); + bias_node = std::make_shared(bias_node, transpose_order); + } + final_node = std::make_shared(matmul, bias_node); + } + + if (withOutFq) { + auto input_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {1}); + auto input_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {20}); + auto output_low = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {0}); + auto output_high = ngraph::opset7::Constant::create(ngraph::element::f32, ngraph::Shape{1}, {10}); + final_node = std::make_shared(final_node, input_low, input_high, + output_low, output_high, 11); + } + + if (swappedInputs) { + auto transpose_order = ngraph::opset7::Constant::create(ngraph::element::i64, ngraph::Shape{2}, + std::vector{1, 0}); + final_node = std::make_shared(final_node, transpose_order); + } + + auto result = std::make_shared(final_node); + return std::make_shared(ngraph::ResultVector{result}, + ngraph::ParameterVector{input_params}); +} + +static void Execute(std::shared_ptr function, std::shared_ptr reference_function) { + ngraph::pass::Manager m; + m.register_pass(); + m.register_pass(); + m.register_pass(); + m.register_pass(); + m.run_passes(function); + ASSERT_NO_THROW(check_rt_info(function)); + + const FunctionsComparator func_comparator = FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES); + const FunctionsComparator::Result result = func_comparator(function, reference_function); + ASSERT_TRUE(result.valid); +} + +typedef std::tuple< + std::vector, // constant input shape, non-const input shape, bias shape + bool, // with bias + bool, // with weights FakeQuantize + bool // with output FakeQuantize +> SwapInputMatmulParams; + +static std::string getTestCaseName(testing::TestParamInfo obj) { + std::vector shapes; + bool withBias, withWeightsFq, withOutFq; + std::tie(shapes, withBias, withWeightsFq, withOutFq) = obj.param; + + std::ostringstream result; + result << "IS1=" << shapes[0] << "_"; + result << "IS2=" << shapes[1] << "_"; + result << "BS=" << shapes[2] << "_"; + result << "bias=" << withBias << "_"; + result << "wFQ=" << withWeightsFq << "_"; + result << "oFQ=" << withOutFq; + return result.str(); +} + +class SwapInputMatmul : public CommonTestUtils::TestsCommon, + public ::testing::WithParamInterface { +public: + void SetUp() override { + std::vector shapes; + bool withBias, withWeightsFq, withOutFq; + std::tie(shapes, withBias, withWeightsFq, withOutFq) = this->GetParam(); + + function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, false); + reference_function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, + withOutFq, true); + } +public: + std::shared_ptr function, reference_function; +}; + +class SwapInputMatmulNotApplied : public CommonTestUtils::TestsCommon, + public ::testing::WithParamInterface { +public: + void SetUp() override { + std::vector shapes; + bool withBias, withWeightsFq, withOutFq; + std::tie(shapes, withBias, withWeightsFq, withOutFq) = this->GetParam(); + + function = CreateMatMulFunction(shapes[0], shapes[1], shapes[2], withBias, withWeightsFq, withOutFq, false); + reference_function = ngraph::clone_function(*function); + } +public: + std::shared_ptr function, reference_function; +}; + +TEST_P(SwapInputMatmul, CompareFunctions) { + Execute(function, reference_function); +} + +TEST_P(SwapInputMatmulNotApplied, CompareFunctions) { + Execute(function, reference_function); +} + +const std::vector> input_shapes_applied = { + {{16, 8}, {8, 8}, {16, 8}}, + {{16, 8}, {8, 8}, {1}} +}; + +const std::vector> input_shapes_not_applied = { + {{1, 8}, {8, 8}, {1, 8}}, + {{8}, {8, 8}, {8}} +}; + +INSTANTIATE_TEST_CASE_P(smoke_swap_input_matmul, SwapInputMatmul, + ::testing::Combine( + ::testing::ValuesIn(input_shapes_applied), + ::testing::ValuesIn(std::vector{false, true}), + ::testing::ValuesIn(std::vector{false, true}), + ::testing::ValuesIn(std::vector{false, true})), + getTestCaseName); + +INSTANTIATE_TEST_CASE_P(smoke_swap_input_matmul, SwapInputMatmulNotApplied, + ::testing::Combine( + ::testing::ValuesIn(input_shapes_not_applied), + ::testing::ValuesIn(std::vector{false, true}), + ::testing::ValuesIn(std::vector{false, true}), + ::testing::ValuesIn(std::vector{false, true})), + getTestCaseName); + +} // namespace testing