[ nG transformation ] Const -> FQ -> Reshape fuse (#2388)
* [ nG transformation ] Const -> FQ -> Reshape fuse Ticket: 39124 * fix dtype incompatibility: uint64 vs size_t * Review comments addressed
This commit is contained in:
committed by
GitHub
parent
23373b5502
commit
97fad1cb35
@@ -0,0 +1,32 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API FakeQuantizeReshapeFusion;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief This transformation looks for a FQ + Reshape pair in the graph and moves
|
||||
* the Reshape operation above the FQ node. Shapes of limit inputs are updated
|
||||
* following FQ broadcasting semantics
|
||||
*/
|
||||
|
||||
// Matcher pass declaration; the pattern and callback are registered in the
// constructor (see fq_reshape_fusion.cpp).
class ngraph::pass::FakeQuantizeReshapeFusion : public ngraph::pass::MatcherPass {
public:
    FakeQuantizeReshapeFusion();
};
|
||||
@@ -0,0 +1,70 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/fq_reshape_fusion.hpp"

#include <algorithm>
#include <memory>
#include <vector>

#include <ngraph/opsets/opset4.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include <ngraph/rt_info.hpp>
|
||||
|
||||
ngraph::pass::FakeQuantizeReshapeFusion::FakeQuantizeReshapeFusion() {
    // Matches a constant-fed FakeQuantize (weights path) whose single consumer
    // is a Reshape: Const -> FQ -> Reshape.
    const auto fq_node_p = ngraph::pattern::wrap_type<opset4::FakeQuantize>(
        {ngraph::pattern::wrap_type<opset4::Constant>(), // for weights only
         ngraph::pattern::any_input(),
         ngraph::pattern::any_input(),
         ngraph::pattern::any_input(),
         ngraph::pattern::any_input()},
        pattern::consumers_count(1));
    const auto reshape_node_p = ngraph::pattern::wrap_type<opset4::Reshape>(
        {fq_node_p, ngraph::pattern::any_input()});

    ngraph::matcher_pass_callback callback = [=](pattern::Matcher &m) {
        const auto &pattern_map = m.get_pattern_value_map();
        const auto fq_node = pattern_map.at(fq_node_p).get_node_shared_ptr();
        // All shape arithmetic below requires static shapes.
        if (fq_node->is_dynamic())
            return false;
        const auto &reshape_node = pattern_map.at(reshape_node_p).get_node_shared_ptr();
        const auto original_data_rank = fq_node->get_input_shape(0).size();
        // First renewed input: the Reshape moved above the FQ, onto its data input.
        OutputVector renewed_inputs = {
            reshape_node->clone_with_new_inputs({fq_node->input_value(0), reshape_node->input_value(1)})};
        // Limit inputs of FakeQuantize: input_low, input_high, output_low, output_high.
        for (size_t i = 1; i < 5; ++i) {
            Output<Node> limit_input = fq_node->input_value(i);
            auto limit_shape = limit_input.get_shape();
            NGRAPH_CHECK(limit_shape.size() <= original_data_rank, "FakeQuantize limit input has unexpected rank");
            if (limit_shape.size() < original_data_rank) // aligning limit rank with data rank
                limit_shape.insert(limit_shape.begin(), original_data_rank - limit_shape.size(),
                                   static_cast<size_t>(1)); // Shape's element type is size_t
            NGRAPH_CHECK(limit_shape.size() == original_data_rank, "FakeQuantize limit input has unexpected rank");
            const auto &limit_size = shape_size(limit_shape);
            const auto &max_element = *std::max_element(limit_shape.begin(), limit_shape.end());
            if (max_element == limit_size) { // per-tensor / per-channel limit
                // Track where the single non-trivial dimension lands after the reshape:
                // keep dims equal to it, collapse everything else to 1.
                auto new_limit_shape = reshape_node->get_output_shape(0);
                std::transform(new_limit_shape.begin(), new_limit_shape.end(), new_limit_shape.begin(),
                               [max_element](size_t dim) { return dim == max_element ? max_element : 1; });
                const auto &new_limit_size = shape_size(new_limit_shape);
                if (new_limit_size == limit_size) { // we tracked future channel placement
                    if (new_limit_shape == limit_input.get_shape())
                        renewed_inputs.push_back(limit_input); // already laid out correctly
                    else
                        renewed_inputs.push_back(reshape_node->copy_with_new_inputs(
                            {limit_input,
                             opset4::Constant::create(element::i64, {new_limit_shape.size()}, new_limit_shape)}));
                    continue;
                }
            }
            // resulting FQ will become or already is more than per-tensor / per-channel
            return false;
        }
        for (auto &new_input : renewed_inputs)
            copy_runtime_info({reshape_node, fq_node}, new_input.get_node_shared_ptr());
        const auto new_fq_node = fq_node->clone_with_new_inputs(renewed_inputs);
        replace_node(reshape_node, new_fq_node);
        new_fq_node->set_friendly_name(fq_node->get_friendly_name());
        copy_runtime_info({fq_node, reshape_node}, new_fq_node);
        return true;
    };

    auto m = std::make_shared<ngraph::pattern::Matcher>(reshape_node_p, "FakeQuantizeReshapeFusion");
    this->register_matcher(m, callback);
}
|
||||
@@ -52,6 +52,7 @@
|
||||
#include <transformations/reduce_l1_decomposition.hpp>
|
||||
#include <transformations/reduce_l2_decomposition.hpp>
|
||||
#include <transformations/common_optimizations/fq_mul_fusion.hpp>
|
||||
#include <transformations/common_optimizations/fq_reshape_fusion.hpp>
|
||||
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
@@ -97,7 +98,6 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
|
||||
decomp->add_matcher<ngraph::pass::BatchNormDecomposition>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertMatMulToFC>();
|
||||
decomp->add_matcher<ngraph::pass::ConvertMatMulToGemm>();
|
||||
decomp->add_matcher<ngraph::pass::PullTransposeThroughFQUp>();
|
||||
decomp->set_name("ngraph::pass::Decompositions");
|
||||
|
||||
// CF is required after all decompositions
|
||||
@@ -112,9 +112,6 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
|
||||
manager.register_pass<ngraph::pass::GroupConvolutionBackpropDataMultiplyFusion>();
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
|
||||
// Multiply the third and fourth input instead of the output of FQ with all const inputs
|
||||
manager.register_pass<ngraph::pass::FakeQuantizeMulFusion>();
|
||||
|
||||
// Convolution/Deconvolution/FullyConnected fusions
|
||||
auto convert_convolutions = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
convert_convolutions->add_matcher<ngraph::pass::ConvertConvolution>();
|
||||
@@ -123,6 +120,12 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
|
||||
convert_convolutions->add_matcher<ngraph::pass::ConvertGroupDeconvolution>();
|
||||
convert_convolutions->set_name("ngraph::pass::ConvertConvolutions");
|
||||
|
||||
auto fq_fusions = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
fq_fusions->add_matcher<FakeQuantizeMulFusion>();
|
||||
fq_fusions->add_matcher<FakeQuantizeReshapeFusion>();
|
||||
fq_fusions->add_matcher<PullTransposeThroughFQUp>();
|
||||
fq_fusions->set_name("ngraph::pass::FakeQuantizeFusions");
|
||||
|
||||
// Convolution/Deconvolution/FullyConnected fusions
|
||||
auto fusion = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
fusion->add_matcher<ngraph::pass::ConvAddFusion>();
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
#include <map>
|
||||
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/function.hpp>
|
||||
#include <common_test_utils/ngraph_test_utils.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/common_optimizations/fq_reshape_fusion.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "cnn_network_ngraph_impl.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace InferenceEngine;
|
||||
|
||||
namespace {
|
||||
|
||||
// Sentinel: a limit input marked with this shape is expected to stay unreshaped
// in the reference graph. Const — it is only ever compared against.
const ngraph::Shape DO_NOT_RESHAPE = ngraph::Shape{0};
|
||||
|
||||
struct FQReshapeFusionTestCase {
|
||||
ngraph::Shape data_shape, il_shape, ih_shape, ol_shape, oh_shape;
|
||||
std::vector<int64_t> reshape_pattern;
|
||||
ngraph::Shape new_il_shape, new_ih_shape, new_ol_shape, new_oh_shape;
|
||||
bool is_negative;
|
||||
};
|
||||
|
||||
class nGraphFQReshapeFusionTests : public CommonTestUtils::TestsCommon, public testing::WithParamInterface<std::tuple<FQReshapeFusionTestCase>> {
public:
    // f     - graph the transformation is applied to
    // ref_f - expected graph after the pass has run
    std::shared_ptr<ngraph::Function> f, ref_f;

    void SetUp() override {
        // NOTE: the original kept an unused `parameters = GetParam()` local; removed.
        const auto& test_case = std::get<0>(GetParam());
        f = get_initial_function(test_case);
        // Negative cases expect the pass to leave the graph untouched.
        if (test_case.is_negative)
            ref_f = get_initial_function(test_case);
        else
            ref_f = get_reference_function(test_case);
    }

private:
    // Builds Const -> FQ (limits fed by Parameters) -> Reshape -> Result.
    std::shared_ptr<ngraph::Function> get_initial_function(const FQReshapeFusionTestCase & test_case) {
        const auto & data = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, test_case.data_shape, 0);
        auto il = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.il_shape);
        auto ih = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ih_shape);
        auto ol = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ol_shape);
        auto oh = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.oh_shape);

        auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(data, il, ih, ol, oh, 42);

        auto reshape_pattern = std::make_shared<ngraph::opset4::Constant>(
                ngraph::element::i64, ngraph::Shape{test_case.reshape_pattern.size()}, test_case.reshape_pattern);
        auto reshape = std::make_shared<ngraph::opset4::Reshape>(fq, reshape_pattern, true);

        auto result = std::make_shared<ngraph::op::Result>(reshape);
        ngraph::ParameterVector params = {il, ih, ol, oh};
        ngraph::ResultVector results = {result};
        return std::make_shared<ngraph::Function>(results, params);
    }

    // Builds the expected fused graph: Reshape moved above the FQ onto its data
    // input, each limit pre-reshaped to its new_*_shape unless DO_NOT_RESHAPE.
    std::shared_ptr<ngraph::Function> get_reference_function(const FQReshapeFusionTestCase & test_case) {
        const auto & data = std::make_shared<ngraph::opset4::Constant>(ngraph::element::f32, test_case.data_shape, 0);
        const auto & reshaped_data = std::make_shared<ngraph::opset4::Reshape>(
                data,
                std::make_shared<ngraph::opset4::Constant>(ngraph::element::i64, ngraph::Shape{test_case.reshape_pattern.size()}, test_case.reshape_pattern),
                true);

        const auto & p_il = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.il_shape);
        ngraph::Output<ngraph::Node> il = p_il;
        const auto & p_ih = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ih_shape);
        ngraph::Output<ngraph::Node> ih = p_ih;
        const auto & p_ol = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.ol_shape);
        ngraph::Output<ngraph::Node> ol = p_ol;
        const auto & p_oh = std::make_shared<ngraph::opset4::Parameter>(ngraph::element::f32, test_case.oh_shape);
        ngraph::Output<ngraph::Node> oh = p_oh;

        if (test_case.new_il_shape != DO_NOT_RESHAPE)
            il = std::make_shared<ngraph::opset4::Reshape>(
                    il, ngraph::opset4::Constant::create(ngraph::element::i64, {test_case.new_il_shape.size()}, test_case.new_il_shape), true);
        if (test_case.new_ih_shape != DO_NOT_RESHAPE)
            ih = std::make_shared<ngraph::opset4::Reshape>(
                    ih, ngraph::opset4::Constant::create(ngraph::element::i64, {test_case.new_ih_shape.size()}, test_case.new_ih_shape), true);
        if (test_case.new_ol_shape != DO_NOT_RESHAPE)
            ol = std::make_shared<ngraph::opset4::Reshape>(
                    ol, ngraph::opset4::Constant::create(ngraph::element::i64, {test_case.new_ol_shape.size()}, test_case.new_ol_shape), true);
        if (test_case.new_oh_shape != DO_NOT_RESHAPE)
            oh = std::make_shared<ngraph::opset4::Reshape>(
                    oh, ngraph::opset4::Constant::create(ngraph::element::i64, {test_case.new_oh_shape.size()}, test_case.new_oh_shape), true);

        auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(reshaped_data, il, ih, ol, oh, 42);

        auto result = std::make_shared<ngraph::op::Result>(fq);
        ngraph::ParameterVector params = {p_il, p_ih, p_ol, p_oh};
        ngraph::ResultVector results = {result};
        return std::make_shared<ngraph::Function>(results, params);
    }
};
|
||||
|
||||
// Runs InitNodeInfo + FakeQuantizeReshapeFusion on `f` and compares against
// the reference graph built in SetUp.
// NOTE(review): the name "ReshapeMatMul" looks copy-pasted from another suite;
// the scenario under test is FQ+Reshape fusion — consider renaming.
TEST_P(nGraphFQReshapeFusionTests, ReshapeMatMul) {
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::InitNodeInfo>();
    manager.register_pass<ngraph::pass::FakeQuantizeReshapeFusion>();
    manager.run_passes(f);

    // The pass must not lose runtime info on the nodes it rewires.
    ASSERT_NO_THROW(check_rt_info(f));
    const auto comparison = compare_functions(f, ref_f);
    ASSERT_TRUE(comparison.first) << comparison.second;
}
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(NGraph, nGraphFQReshapeFusionTests, testing::Values(
|
||||
// positive
|
||||
FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 1}, {1}, {1, 1}, {1, 2, 1, 1}, {2, 3}, {2, 1}, {1, 1}, DO_NOT_RESHAPE, {2, 1}, false},
|
||||
FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 1}, {1}, {1, 1}, {1, 2, 1, 1}, {1, 2, 1, 3}, {1, 2, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, DO_NOT_RESHAPE, false},
|
||||
FQReshapeFusionTestCase{{2, 3}, {2, 1}, {1}, {1, 1}, {1, 1}, {1, 2, 1, 3}, {1, 2, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, {1, 1, 1, 1}, false},
|
||||
// negative
|
||||
FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 3}, {1}, {1, 1}, {1, 2, 1, 1}, {1, 2, 1, 3}, {}, {}, {}, {}, true},
|
||||
FQReshapeFusionTestCase{{1, 2, 1, 3}, {2, 1, 1}, {1}, {1, 1}, {1, 2, 1, 1}, {6}, {}, {}, {}, {}, true}));
|
||||
} // namespace
|
||||
Reference in New Issue
Block a user