[ nG transformation ] Const -> FQ -> Reshape fuse (#2388)

* [ nG transformation ] Const -> FQ -> Reshape fuse
	Ticket: 39124

* fix dtype incompatibility: uint64 vs size_t

* Review comments addressed
Evgenya Stepyreva
2020-09-24 11:44:08 +03:00, committed by GitHub
parent 23373b5502
commit 97fad1cb35
4 changed files with 234 additions and 4 deletions


@@ -0,0 +1,32 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <vector>
#include <transformations_visibility.hpp>
#include <ngraph/pass/graph_rewrite.hpp>
namespace ngraph {
namespace pass {
class TRANSFORMATIONS_API FakeQuantizeReshapeFusion;
} // namespace pass
} // namespace ngraph
/**
* @ingroup ie_transformation_common_api
* @brief This transformation looks for a FQ + Reshape pair in the graph and moves
* the Reshape operation above the FQ node. Shapes of limit inputs are updated
 * following FQ broadcasting semantics.
*/
class ngraph::pass::FakeQuantizeReshapeFusion : public ngraph::pass::MatcherPass {
public:
FakeQuantizeReshapeFusion();
};
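
For context, a hedged sketch (not part of the commit) of the rewrite the pass performs and how it can be invoked through a pass::Manager, mirroring the new functional test; the shapes are taken from its first positive case, and fuse_fq_reshape is a hypothetical wrapper:

// Before:  Const{1,2,1,3} -> FakeQuantize(il{2,1,1}, ih{1}, ol{1,1}, oh{1,2,1,1}) -> Reshape{2,3}
// After:   Const{1,2,1,3} -> Reshape{2,3} -> FakeQuantize(il{2,1}, ih{1,1}, ol{1,1}, oh{2,1})
#include <memory>
#include <ngraph/function.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/common_optimizations/fq_reshape_fusion.hpp>

void fuse_fq_reshape(const std::shared_ptr<ngraph::Function>& f) {
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::FakeQuantizeReshapeFusion>();
    manager.run_passes(f);  // Reshape is pulled above FQ; limit inputs are realigned
}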


@@ -0,0 +1,70 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "transformations/common_optimizations/fq_reshape_fusion.hpp"
#include <memory>
#include <vector>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
ngraph::pass::FakeQuantizeReshapeFusion::FakeQuantizeReshapeFusion() {
const auto fq_node_p = ngraph::pattern::wrap_type<opset4::FakeQuantize>(
{ngraph::pattern::wrap_type<opset4::Constant>(), // for weights only
ngraph::pattern::any_input(),
ngraph::pattern::any_input(),
ngraph::pattern::any_input(),
ngraph::pattern::any_input()},
pattern::consumers_count(1));
const auto reshape_node_p = ngraph::pattern::wrap_type<opset4::Reshape>(
{fq_node_p, ngraph::pattern::any_input()});
ngraph::matcher_pass_callback callback = [=](pattern::Matcher &m) {
const auto &pattern_map = m.get_pattern_value_map();
const auto fq_node = pattern_map.at(fq_node_p).get_node_shared_ptr();
if (fq_node->is_dynamic())
return false;
const auto &reshape_node = pattern_map.at(reshape_node_p).get_node_shared_ptr();
const auto &original_data_rank = fq_node->get_input_shape(0).size();
OutputVector renewed_inputs = {reshape_node->clone_with_new_inputs({fq_node->input_value(0), reshape_node->input_value(1)})};
for (auto i = 1; i < 5; ++i) {
Output<Node> limit_input = fq_node->input_value(i);
auto limit_shape = limit_input.get_shape();
NGRAPH_CHECK(limit_shape.size() <= original_data_rank, "FakeQuantize limit input has unexpected rank");
if (limit_shape.size() < original_data_rank) // aligning limit rank with data rank
limit_shape.insert(limit_shape.begin(), original_data_rank - limit_shape.size(), uint64_t(1));
NGRAPH_CHECK(limit_shape.size() == original_data_rank, "FakeQuantize limit input has unexpected rank");
const auto &limit_size = shape_size(limit_shape);
const auto &max_element = *std::max_element(limit_shape.begin(), limit_shape.end());
if (max_element == limit_size) { // per-tensor / per-channel limit
auto new_limit_shape = reshape_node->get_output_shape(0);
std::transform(new_limit_shape.begin(), new_limit_shape.end(), new_limit_shape.begin(),
[max_element](size_t &dim) { return dim == max_element ? max_element : 1; });
const auto &new_limit_size = shape_size(new_limit_shape);
if (new_limit_size == limit_size) { // the channel dimension was tracked through the reshape
if (new_limit_shape == limit_input.get_shape())
renewed_inputs.push_back(limit_input);
else
renewed_inputs.push_back(reshape_node->copy_with_new_inputs(
{limit_input, opset4::Constant::create(element::i64, {new_limit_shape.size()}, new_limit_shape)}));
continue;
}
}
// otherwise the resulting FQ limits would be broader than per-tensor / per-channel
return false;
}
for (auto &new_input : renewed_inputs)
copy_runtime_info({reshape_node, fq_node}, new_input.get_node_shared_ptr());
const auto new_fq_node = fq_node->clone_with_new_inputs(renewed_inputs);
replace_node(reshape_node, new_fq_node);
new_fq_node->set_friendly_name(fq_node->get_friendly_name());
copy_runtime_info({fq_node, reshape_node}, new_fq_node);
return true;
};
auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_node_p, "FakeQuantizeReshapeFusion");
this->register_matcher(m, callback);
}
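
To make the limit-shape bookkeeping in the callback concrete, here is a minimal standalone sketch, assuming a hypothetical helper realign_limit that mirrors the loop body above (Shape stands in for ngraph::Shape). It reproduces the first positive test case: a per-channel input-low of shape {2,1,1} on data of shape {1,2,1,3}, reshaped to {2,3}, becomes {2,1}:

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

using Shape = std::vector<size_t>;

// Align the limit rank with the data rank, then keep the single non-trivial
// dimension wherever it lands in the reshaped output. Hypothetical helper,
// not the commit's code.
Shape realign_limit(Shape limit, size_t data_rank, const Shape& out_shape) {
    if (limit.size() < data_rank)
        limit.insert(limit.begin(), data_rank - limit.size(), size_t(1));
    size_t limit_size = 1;
    for (size_t d : limit)
        limit_size *= d;
    const size_t max_element = *std::max_element(limit.begin(), limit.end());
    // max_element == limit_size holds exactly for per-tensor / per-channel limits
    assert(max_element == limit_size && "broader than per-channel: the pass bails out");
    Shape new_limit(out_shape.size());
    std::transform(out_shape.begin(), out_shape.end(), new_limit.begin(),
                   [max_element](size_t dim) { return dim == max_element ? max_element : size_t(1); });
    return new_limit;
}

int main() {
    // the channel dimension of size 2 is tracked into the new layout {2,3}
    const Shape new_limit = realign_limit({2, 1, 1}, 4, {2, 3});
    assert((new_limit == Shape{2, 1}));  // matches the first positive test case
    return 0;
}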


@@ -52,6 +52,7 @@
#include <transformations/reduce_l1_decomposition.hpp>
#include <transformations/reduce_l2_decomposition.hpp>
#include <transformations/common_optimizations/fq_mul_fusion.hpp>
#include <transformations/common_optimizations/fq_reshape_fusion.hpp>
#include <ngraph/pass/constant_folding.hpp>
#include <ngraph/pass/manager.hpp>
@@ -97,7 +98,6 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
decomp->add_matcher<ngraph::pass::ConvertMatMulToFC>();
decomp->add_matcher<ngraph::pass::ConvertMatMulToGemm>();
decomp->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
decomp->set_name("ngraph::pass::Decompositions");
// CF is required after all decompositions
@@ -112,9 +112,6 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
manager.register_pass<ngraph::pass::GroupConvolutionBackpropDataMultiplyFusion>();
manager.register_pass<ngraph::pass::ConstantFolding>();
// Multiply the third and fourth input instead of the output of FQ with all const inputs
manager.register_pass<ngraph::pass::FakeQuantizeMulFusion>();
// Convolution/Deconvolution/FullyConnected fusions
auto convert_convolutions = manager.register_pass<ngraph::pass::GraphRewrite>();
convert_convolutions->add_matcher<ngraph::pass::ConvertConvolution>();
@@ -123,6 +120,12 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
convert_convolutions->add_matcher<ngraph::pass::ConvertGroupDeconvolution>();
convert_convolutions->set_name("ngraph::pass::ConvertConvolutions");
auto fq_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
fq_fusions->add_matcher<FakeQuantizeMulFusion>();
fq_fusions->add_matcher<FakeQuantizeReshapeFusion>();
fq_fusions->add_matcher<PullTransposeThroughFQUp>();
fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
// Convolution/Deconvolution/FullyConnected fusions
auto fusion = manager.register_pass<ngraph::pass::GraphRewrite>();
fusion->add_matcher<ngraph::pass::ConvAddFusion>();


@@ -0,0 +1,125 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <string>
#include <memory>
#include <map>
#include <ngraph/opsets/opset4.hpp>
#include <ngraph/function.hpp>
#include <common_test_utils/ngraph_test_utils.hpp>
#include <ngraph/pass/manager.hpp>
#include <transformations/common_optimizations/fq_reshape_fusion.hpp>
#include <transformations/init_node_info.hpp>
#include "cnn_network_ngraph_impl.hpp"
using namespace testing;
using namespace InferenceEngine;
namespace {
ngraph::Shape DO_NOT_RESHAPE = ngraph::Shape{0};
struct FQReshapeFusionTestCase {
ngraph::Shape data_shape, il_shape, ih_shape, ol_shape, oh_shape;
std::vector<int64_t> reshape_pattern;
ngraph::Shape new_il_shape, new_ih_shape, new_ol_shape, new_oh_shape;
bool is_negative;
};
class nGraphFQReshapeFusionTests : public CommonTestUtils::TestsCommon, public testing::WithParamInterface<std::tuple<FQReshapeFusionTestCase>> {
public:
std::shared_ptr<ngraph::Function> f, ref_f;
void SetUp() override {
const auto& test_case = std::get<0>(GetParam());
f = get_initial_function(test_case);
if (test_case.is_negative)
ref_f = get_initial_function(test_case);
else
ref_f = get_reference_function(test_case);
}
private:
std::shared_ptr<ngraph::Function> get_initial_function(const FQReshapeFusionTestCase & test_case) {
const auto & data = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, test_case.data_shape, 0);
auto il = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.il_shape);
auto ih = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ih_shape);
auto ol = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ol_shape);
auto oh = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.oh_shape);
auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(data, il, ih, ol, oh, 42);
auto reshape_pattern = std::make_shared<ngraph::opset4::Constant>(
ngraph::element::i64, ngraph::Shape{test_case.reshape_pattern.size()}, test_case.reshape_pattern);
auto reshape = std::make_shared<ngraph::opset4::Reshape>(fq, reshape_pattern, true);
auto result = std::make_shared<ngraph::op::Result>(reshape);
ngraph::ParameterVector params = {il, ih, ol, oh};
ngraph::ResultVector results = {result};
return std::make_shared<ngraph::Function>(results, params);
}
std::shared_ptr<ngraph::Function> get_reference_function(const FQReshapeFusionTestCase & test_case) {
const auto & data = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, test_case.data_shape, 0);
const auto & reshaped_data = std::make_shared<ngraph::opset4::Reshape>(
data,
std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{test_case.reshape_pattern.size()}, test_case.reshape_pattern),
true);
const auto & p_il = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.il_shape);
ngraph::Output<ngraph::Node> il = p_il;
const auto & p_ih = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ih_shape);
ngraph::Output<ngraph::Node> ih = p_ih;
const auto & p_ol = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ol_shape);
ngraph::Output<ngraph::Node> ol = p_ol;
const auto & p_oh = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.oh_shape);
ngraph::Output<ngraph::Node> oh = p_oh;
if (test_case.new_il_shape != DO_NOT_RESHAPE)
il = std::make_shared<ngraph::opset4::Reshape>(
il, ngraph::opset4::Constant::create(ngraph::element::i64, {test_case.new_il_shape.size()}, test_case.new_il_shape), true);
if (test_case.new_ih_shape != DO_NOT_RESHAPE)
ih = std::make_shared<ngraph::opset4::Reshape>(
ih, ngraph::opset4::Constant::create(ngraph::element::i64, {test_case.new_ih_shape.size()}, test_case.new_ih_shape), true);
if (test_case.new_ol_shape != DO_NOT_RESHAPE)
ol = std::make_shared<ngraph::opset4::Reshape>(
ol, ngraph::opset4::Constant::create(ngraph::element::i64, {test_case.new_ol_shape.size()}, test_case.new_ol_shape), true);
if (test_case.new_oh_shape != DO_NOT_RESHAPE)
oh = std::make_shared<ngraph::opset4::Reshape>(
oh, ngraph::opset4::Constant::create(ngraph::element::i64, {test_case.new_oh_shape.size()}, test_case.new_oh_shape), true);
auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(reshaped_data, il, ih, ol, oh, 42);
auto result = std::make_shared<ngraph::op::Result>(fq);
ngraph::ParameterVector params = {p_il, p_ih, p_ol, p_oh};
ngraph::ResultVector results = {result};
return std::make_shared<ngraph::Function>(results, params);
}
};
TEST_P(nGraphFQReshapeFusionTests, FQReshapeFusion) {
ngraph::pass::Manager manager;
manager.register_pass<ngraph::pass::InitNodeInfo>();
manager.register_pass<ngraph::pass::FakeQuantizeReshapeFusion>();
manager.run_passes(f);
ASSERT_NO_THROW(check_rt_info(f));
auto res = compare_functions(f, ref_f);
ASSERT_TRUE(res.first) << res.second;
}
INSTANTIATE_TEST_CASE_P(NGraph, nGraphFQReshapeFusionTests, testing::Values(
// positive
FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 1}, {1}, {1, 1}, {1, 2, 1, 1}, {2, 3}, {2, 1}, {1, 1}, DO_NOT_RESHAPE, {2, 1}, false},
FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 1}, {1}, {1, 1}, {1, 2, 1, 1}, {1, 2, 1, 3}, {1, 2, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, DO_NOT_RESHAPE, false},
FQReshapeFusionTestCase{{2, 3}, {2, 1}, {1}, {1, 1}, {1, 1}, {1, 2, 1, 3}, {1, 2, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, false},
// negative
FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 3}, {1}, {1, 1}, {1, 2, 1, 1}, {1, 2, 1, 3}, {}, {}, {}, {}, true},
FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 1}, {1}, {1, 1}, {1, 2, 1, 1}, {6}, {}, {}, {}, {}, true}));
} // namespace