Backport of FQ+Mul transform to master (#2214)
* Backport of FQ+Mul transform to master * Accept any type of input to FQ in the transformation * Test the fusion when all FQ inputs are non-const * Fusion test when only one output limit is const * Test passing the output of FQ to second input of Mul
This commit is contained in:
parent
c13ec24e1e
commit
dda6d9136b
@ -0,0 +1,32 @@
|
|||||||
|
// Copyright (C) 2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <transformations_visibility.hpp>
|
||||||
|
|
||||||
|
#include <ngraph/pass/graph_rewrite.hpp>
|
||||||
|
|
||||||
|
namespace ngraph {
|
||||||
|
namespace pass {
|
||||||
|
|
||||||
|
class TRANSFORMATIONS_API FakeQuantizeMulFusion;
|
||||||
|
|
||||||
|
} // namespace pass
|
||||||
|
} // namespace ngraph
|
||||||
|
|
||||||
|
/**
|
||||||
|
* @ingroup ie_transformation_common_api
|
||||||
|
* @brief This transformation looks for a FQ + Mul pair in the graph and moves
|
||||||
|
* the Mul operation above the FQ node. The last two inputs of FQ are multiplied
|
||||||
|
* by the value that was originally below the FQ node.
|
||||||
|
*/
|
||||||
|
|
||||||
|
class ngraph::pass::FakeQuantizeMulFusion : public ngraph::pass::MatcherPass {
|
||||||
|
public:
|
||||||
|
FakeQuantizeMulFusion();
|
||||||
|
};
|
@ -0,0 +1,108 @@
|
|||||||
|
// Copyright (C) 2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include "transformations/common_optimizations/fq_mul_fusion.hpp"
|
||||||
|
#include "transformations/utils/utils.hpp"
|
||||||
|
|
||||||
|
#include <memory>
|
||||||
|
#include <vector>
|
||||||
|
|
||||||
|
#include <ngraph/opsets/opset4.hpp>
|
||||||
|
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||||
|
#include <ngraph/rt_info.hpp>
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
std::pair<ngraph::Output<ngraph::Node>, ngraph::Output<ngraph::Node>>
|
||||||
|
get_adjusted_output_range(ngraph::Output<ngraph::Node> out_low,
|
||||||
|
ngraph::Output<ngraph::Node> out_high,
|
||||||
|
ngraph::Output<ngraph::Node> multiplier) {
|
||||||
|
const auto mul_out_low = std::make_shared<ngraph::opset4::Multiply>(out_low, multiplier);
|
||||||
|
const auto mul_out_high = std::make_shared<ngraph::opset4::Multiply>(out_high, multiplier);
|
||||||
|
copy_runtime_info({out_low.get_node_shared_ptr(), multiplier.get_node_shared_ptr()},
|
||||||
|
mul_out_low);
|
||||||
|
copy_runtime_info({out_high.get_node_shared_ptr(), multiplier.get_node_shared_ptr()},
|
||||||
|
mul_out_high);
|
||||||
|
|
||||||
|
ngraph::OutputVector new_out_low(1), new_out_high(1);
|
||||||
|
|
||||||
|
if (!mul_out_low->constant_fold(new_out_low, {out_low, multiplier})) {
|
||||||
|
new_out_low[0] = mul_out_low;
|
||||||
|
}
|
||||||
|
|
||||||
|
if (!mul_out_high->constant_fold(new_out_high, {out_high, multiplier})) {
|
||||||
|
new_out_high[0] = mul_out_high;
|
||||||
|
}
|
||||||
|
|
||||||
|
return {new_out_low[0], new_out_high[0]};
|
||||||
|
}
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
// This transformation multiplies the "output_low" and "output_high" inputs of the FQ operation
|
||||||
|
// by the constant value that before transormation is used to multiply the output of FQ.
|
||||||
|
// Both output_low and output_high are multiplied by the value represented as C (a constant) below.
|
||||||
|
// In case any of the FQ inputs (out_L, out_H) is constant, it gets constant folded with C.
|
||||||
|
//
|
||||||
|
// data in_L in_H out_L out_H
|
||||||
|
// | | | | |
|
||||||
|
// | | | | | data in_L in_H out_L * C out_H * C
|
||||||
|
// v v v v v | | | | |
|
||||||
|
// +-------------------------+ | | | | |
|
||||||
|
// | FakeQuantize | v v v v v
|
||||||
|
// +-------------------------+ +-----------------------------------+
|
||||||
|
// | =====> | FakeQuantize |
|
||||||
|
// v +-----------------------------------+
|
||||||
|
// +----------+ |
|
||||||
|
// | Multiply | <--- C v
|
||||||
|
// +----+-----+
|
||||||
|
// |
|
||||||
|
// v
|
||||||
|
//
|
||||||
|
|
||||||
|
ngraph::pass::FakeQuantizeMulFusion::FakeQuantizeMulFusion() {
|
||||||
|
const auto fq_output_low_p = ngraph::pattern::any_input();
|
||||||
|
const auto fq_output_high_p = ngraph::pattern::any_input();
|
||||||
|
|
||||||
|
const auto fq_node_p = ngraph::pattern::wrap_type<opset4::FakeQuantize>(
|
||||||
|
{ngraph::pattern::any_input(),
|
||||||
|
ngraph::pattern::any_input(),
|
||||||
|
ngraph::pattern::any_input(),
|
||||||
|
fq_output_low_p,
|
||||||
|
fq_output_high_p},
|
||||||
|
pattern::consumers_count(1));
|
||||||
|
|
||||||
|
const auto mul_constant_p = ngraph::pattern::wrap_type<opset4::Constant>();
|
||||||
|
const auto mul_node_p = ngraph::pattern::wrap_type<opset4::Multiply>(
|
||||||
|
{fq_node_p, mul_constant_p}, pattern::consumers_count(1));
|
||||||
|
|
||||||
|
ngraph::matcher_pass_callback callback = [=](pattern::Matcher &m) {
|
||||||
|
const auto& pattern_map = m.get_pattern_value_map();
|
||||||
|
|
||||||
|
const auto fq_node = pattern_map.at(fq_node_p).get_node_shared_ptr();
|
||||||
|
|
||||||
|
const auto original_output_low = pattern_map.at(fq_output_low_p);
|
||||||
|
const auto original_output_high = pattern_map.at(fq_output_high_p);
|
||||||
|
const auto mul_constant = pattern_map.at(mul_constant_p);
|
||||||
|
|
||||||
|
const auto new_output_limits = get_adjusted_output_range(
|
||||||
|
original_output_low, original_output_high, mul_constant);
|
||||||
|
|
||||||
|
const auto new_fq_node = fq_node->clone_with_new_inputs({fq_node->input_value(0),
|
||||||
|
fq_node->input_value(1),
|
||||||
|
fq_node->input_value(2),
|
||||||
|
new_output_limits.first,
|
||||||
|
new_output_limits.second});
|
||||||
|
|
||||||
|
const auto mul_node = pattern_map.at(mul_node_p).get_node_shared_ptr();
|
||||||
|
replace_node(mul_node, new_fq_node);
|
||||||
|
|
||||||
|
new_fq_node->set_friendly_name(fq_node->get_friendly_name());
|
||||||
|
copy_runtime_info({fq_node, mul_node}, new_fq_node);
|
||||||
|
|
||||||
|
return true;
|
||||||
|
};
|
||||||
|
|
||||||
|
auto m = std::make_shared<ngraph::pattern::Matcher>(mul_node_p,
|
||||||
|
"FakeQuantizeMulFusion");
|
||||||
|
this->register_matcher(m, callback);
|
||||||
|
}
|
@ -51,6 +51,7 @@
|
|||||||
#include <transformations/hswish_decomposition.hpp>
|
#include <transformations/hswish_decomposition.hpp>
|
||||||
#include <transformations/reduce_l1_decomposition.hpp>
|
#include <transformations/reduce_l1_decomposition.hpp>
|
||||||
#include <transformations/reduce_l2_decomposition.hpp>
|
#include <transformations/reduce_l2_decomposition.hpp>
|
||||||
|
#include <transformations/common_optimizations/fq_mul_fusion.hpp>
|
||||||
|
|
||||||
#include <ngraph/pass/constant_folding.hpp>
|
#include <ngraph/pass/constant_folding.hpp>
|
||||||
#include <ngraph/pass/manager.hpp>
|
#include <ngraph/pass/manager.hpp>
|
||||||
@ -111,6 +112,9 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
|
|||||||
manager.register_pass<ngraph::pass::GroupConvolutionBackpropDataMultiplyFusion>();
|
manager.register_pass<ngraph::pass::GroupConvolutionBackpropDataMultiplyFusion>();
|
||||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||||
|
|
||||||
|
// Multiply the thrird and fourth input instead of the output of FQ with all const inputs
|
||||||
|
manager.register_pass<ngraph::pass::FakeQuantizeMulFusion>();
|
||||||
|
|
||||||
// Convolution/Deconvolution/FullyConnected fusions
|
// Convolution/Deconvolution/FullyConnected fusions
|
||||||
auto convert_convolutions = manager.register_pass<ngraph::pass::GraphRewrite>();
|
auto convert_convolutions = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||||
convert_convolutions->add_matcher<ngraph::pass::ConvertConvolution>();
|
convert_convolutions->add_matcher<ngraph::pass::ConvertConvolution>();
|
||||||
|
@ -0,0 +1,362 @@
|
|||||||
|
// Copyright (C) 2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#include <gtest/gtest.h>
|
||||||
|
#include <memory>
|
||||||
|
#include <tuple>
|
||||||
|
|
||||||
|
#include <ie_core.hpp>
|
||||||
|
|
||||||
|
#include "common_test_utils/common_utils.hpp"
|
||||||
|
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||||
|
#include "common_test_utils/test_common.hpp"
|
||||||
|
#include "functional_test_utils/plugin_cache.hpp"
|
||||||
|
|
||||||
|
#include <ngraph/function.hpp>
|
||||||
|
#include <ngraph/opsets/opset4.hpp>
|
||||||
|
#include <ngraph/pass/manager.hpp>
|
||||||
|
#include <transformations/common_optimizations/fq_mul_fusion.hpp>
|
||||||
|
#include <transformations/init_node_info.hpp>
|
||||||
|
|
||||||
|
namespace LayerTestsDefinitions {
|
||||||
|
|
||||||
|
using FQMulFusionParams =
|
||||||
|
std::tuple<ngraph::Shape, // FQ data shape
|
||||||
|
ngraph::Shape, // in_* shape
|
||||||
|
ngraph::Shape, // out_* shape
|
||||||
|
ngraph::Shape, // Mul constant shape
|
||||||
|
ngraph::Shape>; // Expected shape of the new out_* constants
|
||||||
|
|
||||||
|
class FQMulFusion : public testing::WithParamInterface<FQMulFusionParams>,
|
||||||
|
public CommonTestUtils::TestsCommon {
|
||||||
|
public:
|
||||||
|
void SetUp() override {
|
||||||
|
ngraph::Shape data_shape, in_shape, out_shape, mul_const_shape, expected_out_shape;
|
||||||
|
std::tie(data_shape, in_shape, out_shape, mul_const_shape, expected_out_shape) =
|
||||||
|
this->GetParam();
|
||||||
|
|
||||||
|
const auto data = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, data_shape, {0.0f});
|
||||||
|
const auto in_low = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, in_shape, {-0.5f});
|
||||||
|
const auto in_high = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, in_shape, {0.5f});
|
||||||
|
const auto out_low = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, out_shape, {0.0f});
|
||||||
|
const auto out_high = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, out_shape, {100.0f});
|
||||||
|
const auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(
|
||||||
|
data, in_low, in_high, out_low, out_high, 255);
|
||||||
|
|
||||||
|
const auto mul_value = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, mul_const_shape, {3.14f});
|
||||||
|
const auto mul = std::make_shared<ngraph::opset4::Multiply>(fq, mul_value);
|
||||||
|
|
||||||
|
m_function = std::make_shared<ngraph::Function>(
|
||||||
|
ngraph::OutputVector{mul}, ngraph::ParameterVector{}, "FQMulFusion");
|
||||||
|
|
||||||
|
const auto expected_data = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, data_shape, {0.0f});
|
||||||
|
const auto expected_in_low = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, in_shape, {-0.5f});
|
||||||
|
const auto expected_in_high = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, in_shape, {0.5f});
|
||||||
|
const auto expected_out_low = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, expected_out_shape, {0.0f});
|
||||||
|
const auto expected_out_high = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, expected_out_shape, {314.0f});
|
||||||
|
|
||||||
|
const auto expected_fq =
|
||||||
|
std::make_shared<ngraph::opset4::FakeQuantize>(expected_data,
|
||||||
|
expected_in_low, expected_in_high, expected_out_low, expected_out_high, 255);
|
||||||
|
|
||||||
|
m_expected_function = std::make_shared<ngraph::Function>(
|
||||||
|
ngraph::OutputVector{expected_fq}, ngraph::ParameterVector{}, "FQMulFusion_expected");
|
||||||
|
}
|
||||||
|
|
||||||
|
std::shared_ptr<ngraph::Function> m_function;
|
||||||
|
std::shared_ptr<ngraph::Function> m_expected_function;
|
||||||
|
};
|
||||||
|
|
||||||
|
TEST_P(FQMulFusion, ExpectFusion) {
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ngraph::pass::FakeQuantizeMulFusion>();
|
||||||
|
|
||||||
|
manager.run_passes(m_function);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(m_function));
|
||||||
|
|
||||||
|
const auto res = compare_functions(m_function, m_expected_function);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
};
|
||||||
|
|
||||||
|
namespace {
|
||||||
|
INSTANTIATE_TEST_CASE_P(ScalarFQParams_C6_4D_channel_0, FQMulFusion,
|
||||||
|
::testing::Combine(::testing::Values(ngraph::Shape{64, 3, 7, 7}),
|
||||||
|
::testing::Values(ngraph::Shape{}),
|
||||||
|
::testing::Values(ngraph::Shape{}),
|
||||||
|
::testing::Values(ngraph::Shape{64, 1, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{64, 1, 1, 1})));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(ScalarFQParams_C6_4D_channel_1, FQMulFusion,
|
||||||
|
::testing::Combine(::testing::Values(ngraph::Shape{64, 3, 7, 7}),
|
||||||
|
::testing::Values(ngraph::Shape{}),
|
||||||
|
::testing::Values(ngraph::Shape{}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 3, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 3, 1, 1})));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(ScalarFQParams_C6_scalar, FQMulFusion,
|
||||||
|
::testing::Combine(::testing::Values(ngraph::Shape{64, 3, 7, 7}),
|
||||||
|
::testing::Values(ngraph::Shape{}),
|
||||||
|
::testing::Values(ngraph::Shape{}),
|
||||||
|
::testing::Values(ngraph::Shape{}),
|
||||||
|
::testing::Values(ngraph::Shape{})));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(FQOutputs1D_C6_scalar, FQMulFusion,
|
||||||
|
::testing::Combine(::testing::Values(ngraph::Shape{64, 3, 7, 7}),
|
||||||
|
::testing::Values(ngraph::Shape{}),
|
||||||
|
::testing::Values(ngraph::Shape{1}),
|
||||||
|
::testing::Values(ngraph::Shape{}),
|
||||||
|
::testing::Values(ngraph::Shape{1})));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(FQOutputs_NHWC_C6_scalar, FQMulFusion,
|
||||||
|
::testing::Combine(::testing::Values(ngraph::Shape{1, 7, 7, 3}),
|
||||||
|
::testing::Values(ngraph::Shape{}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 1, 1, 3}),
|
||||||
|
::testing::Values(ngraph::Shape{}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 1, 1, 3})));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(FQOutputs_NCHW_C6_scalar, FQMulFusion,
|
||||||
|
::testing::Combine(::testing::Values(ngraph::Shape{1, 3, 7, 7}),
|
||||||
|
::testing::Values(ngraph::Shape{}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 3, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 3, 1, 1})));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(FQInputs_4D_with_channel_dimension, FQMulFusion,
|
||||||
|
::testing::Combine(::testing::Values(ngraph::Shape{1, 64, 3, 3}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 64, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 64, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 64, 1, 1})));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(FQInputs_4D_per__multiplier_with_channel, FQMulFusion,
|
||||||
|
::testing::Combine(::testing::Values(ngraph::Shape{1, 64, 3, 3}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 64, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 64, 1, 1})));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(FQInputs_4D_with_channel__multiplier_4D_per_tensor, FQMulFusion,
|
||||||
|
::testing::Combine(::testing::Values(ngraph::Shape{1, 64, 3, 3}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 64, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 64, 1, 1})));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(FQInputs_4D__multiplier_channel_3rd_dim, FQMulFusion,
|
||||||
|
::testing::Combine(::testing::Values(ngraph::Shape{1, 64, 3, 3}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 64, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 1, 3, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 64, 3, 1})));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(FQOutputs_1D__multiplier_3D, FQMulFusion,
|
||||||
|
::testing::Combine(::testing::Values(ngraph::Shape{1, 64, 3, 3}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 64, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 3, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 3, 1})));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(FQ_all_ones__multiplier_4D_with_channel, FQMulFusion,
|
||||||
|
::testing::Combine(::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 64, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 64, 1, 1})));
|
||||||
|
|
||||||
|
INSTANTIATE_TEST_CASE_P(FQInOUt_ones__multiplier_4D_with_channel, FQMulFusion,
|
||||||
|
::testing::Combine(::testing::Values(ngraph::Shape{1, 64, 3, 3}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 64, 3, 3}),
|
||||||
|
::testing::Values(ngraph::Shape{1, 64, 3, 3})));
|
||||||
|
|
||||||
|
TEST(FQMulFusion_NonConstInputs, AllInputsNonConst) {
|
||||||
|
const auto data = std::make_shared<ngraph::opset4::Parameter>(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{1, 3, 224, 224});
|
||||||
|
const auto in_low =
|
||||||
|
std::make_shared<ngraph::opset4::Parameter>(ngraph::element::Type_t::f32, ngraph::Shape{});
|
||||||
|
const auto in_high =
|
||||||
|
std::make_shared<ngraph::opset4::Parameter>(ngraph::element::Type_t::f32, ngraph::Shape{});
|
||||||
|
const auto out_low =
|
||||||
|
std::make_shared<ngraph::opset4::Parameter>(ngraph::element::Type_t::f32, ngraph::Shape{});
|
||||||
|
const auto out_high =
|
||||||
|
std::make_shared<ngraph::opset4::Parameter>(ngraph::element::Type_t::f32, ngraph::Shape{});
|
||||||
|
const auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(
|
||||||
|
data, in_low, in_high, out_low, out_high, 42);
|
||||||
|
|
||||||
|
const auto mul_value = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{}, {3.14f});
|
||||||
|
const auto mul = std::make_shared<ngraph::opset4::Multiply>(fq, mul_value);
|
||||||
|
|
||||||
|
auto function = std::make_shared<ngraph::Function>(ngraph::OutputVector{mul},
|
||||||
|
ngraph::ParameterVector{data, in_low, in_high, out_low, out_high});
|
||||||
|
|
||||||
|
const auto expected_out_low = std::make_shared<ngraph::opset4::Multiply>(out_low, mul_value);
|
||||||
|
const auto expected_out_high = std::make_shared<ngraph::opset4::Multiply>(out_high, mul_value);
|
||||||
|
|
||||||
|
const auto expected_fq = std::make_shared<ngraph::opset4::FakeQuantize>(
|
||||||
|
data, in_low, in_high, expected_out_low, expected_out_high, 42);
|
||||||
|
|
||||||
|
const auto expected_function =
|
||||||
|
std::make_shared<ngraph::Function>(ngraph::OutputVector{expected_fq},
|
||||||
|
ngraph::ParameterVector{data, in_low, in_high, out_low, out_high});
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ngraph::pass::FakeQuantizeMulFusion>();
|
||||||
|
|
||||||
|
manager.run_passes(function);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(function));
|
||||||
|
|
||||||
|
const auto res = compare_functions(function, expected_function);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FQMulFusion_NonConstInputs, FQ_out_high_const) {
|
||||||
|
const auto data = std::make_shared<ngraph::opset4::Parameter>(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{1, 3, 224, 224});
|
||||||
|
const auto in_low =
|
||||||
|
std::make_shared<ngraph::opset4::Parameter>(ngraph::element::Type_t::f32, ngraph::Shape{});
|
||||||
|
const auto in_high =
|
||||||
|
std::make_shared<ngraph::opset4::Parameter>(ngraph::element::Type_t::f32, ngraph::Shape{});
|
||||||
|
const auto out_low =
|
||||||
|
std::make_shared<ngraph::opset4::Parameter>(ngraph::element::Type_t::f32, ngraph::Shape{});
|
||||||
|
const auto out_high = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{}, {100.0f});
|
||||||
|
const auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(
|
||||||
|
data, in_low, in_high, out_low, out_high, 42);
|
||||||
|
|
||||||
|
const auto mul_value = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{}, {3.14f});
|
||||||
|
const auto mul = std::make_shared<ngraph::opset4::Multiply>(fq, mul_value);
|
||||||
|
|
||||||
|
auto function = std::make_shared<ngraph::Function>(ngraph::OutputVector{mul},
|
||||||
|
ngraph::ParameterVector{data, in_low, in_high, out_low});
|
||||||
|
|
||||||
|
const auto expected_out_low = std::make_shared<ngraph::opset4::Multiply>(out_low, mul_value);
|
||||||
|
// this constant should be created by constant folding of the last FQ input
|
||||||
|
const auto expected_out_high = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{}, {314.0f});
|
||||||
|
|
||||||
|
const auto expected_fq = std::make_shared<ngraph::opset4::FakeQuantize>(
|
||||||
|
data, in_low, in_high, expected_out_low, expected_out_high, 42);
|
||||||
|
|
||||||
|
const auto expected_function =
|
||||||
|
std::make_shared<ngraph::Function>(ngraph::OutputVector{expected_fq},
|
||||||
|
ngraph::ParameterVector{data, in_low, in_high, out_low});
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ngraph::pass::FakeQuantizeMulFusion>();
|
||||||
|
|
||||||
|
manager.run_passes(function);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(function));
|
||||||
|
|
||||||
|
const auto res = compare_functions(function, expected_function);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FQMulFusion_FQ_Mul_inputs, FQ_out_to_mul_input_2) {
|
||||||
|
const auto data = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{1, 3, 224, 224}, {0.0f});
|
||||||
|
const auto in_low = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{}, {-0.5f});
|
||||||
|
const auto in_high = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{}, {0.5f});
|
||||||
|
const auto out_low = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{}, {0.0f});
|
||||||
|
const auto out_high = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{}, {100.0f});
|
||||||
|
const auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(
|
||||||
|
data, in_low, in_high, out_low, out_high, 42);
|
||||||
|
|
||||||
|
const auto mul_value = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{}, {3.14f});
|
||||||
|
// here the FQ's output is passed to the second input of the Mul operation
|
||||||
|
const auto mul = std::make_shared<ngraph::opset4::Multiply>(mul_value, fq);
|
||||||
|
|
||||||
|
auto function =
|
||||||
|
std::make_shared<ngraph::Function>(ngraph::OutputVector{mul}, ngraph::ParameterVector{});
|
||||||
|
|
||||||
|
const auto expected_out_low = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{}, {0.0f});
|
||||||
|
const auto expected_out_high = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{}, {314.0f});
|
||||||
|
|
||||||
|
const auto expected_fq = std::make_shared<ngraph::opset4::FakeQuantize>(
|
||||||
|
data, in_low, in_high, expected_out_low, expected_out_high, 42);
|
||||||
|
|
||||||
|
const auto expected_function = std::make_shared<ngraph::Function>(
|
||||||
|
ngraph::OutputVector{expected_fq}, ngraph::ParameterVector{});
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ngraph::pass::FakeQuantizeMulFusion>();
|
||||||
|
|
||||||
|
manager.run_passes(function);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(function));
|
||||||
|
|
||||||
|
const auto res = compare_functions(function, expected_function);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
TEST(FQMulFusion_FQ_Mul_inputs, FQ_out_to_mul_input_2_param) {
|
||||||
|
const auto data = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{1, 3, 224, 224}, {0.0f});
|
||||||
|
const auto in_low = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{}, {-0.5f});
|
||||||
|
const auto in_high = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{}, {0.5f});
|
||||||
|
const auto out_low = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{}, {0.0f});
|
||||||
|
// out_high is a parameter, which means it should not be constant folded
|
||||||
|
const auto out_high =
|
||||||
|
std::make_shared<ngraph::opset4::Parameter>(ngraph::element::Type_t::f32, ngraph::Shape{});
|
||||||
|
const auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(
|
||||||
|
data, in_low, in_high, out_low, out_high, 42);
|
||||||
|
|
||||||
|
const auto mul_value = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{}, {3.14f});
|
||||||
|
// and here the output of FQ is passed as the second input of Mul
|
||||||
|
const auto mul = std::make_shared<ngraph::opset4::Multiply>(mul_value, fq);
|
||||||
|
|
||||||
|
auto function = std::make_shared<ngraph::Function>(
|
||||||
|
ngraph::OutputVector{mul}, ngraph::ParameterVector{out_high});
|
||||||
|
|
||||||
|
const auto expected_out_low = ngraph::opset4::Constant::create(
|
||||||
|
ngraph::element::Type_t::f32, ngraph::Shape{}, {0.0f});
|
||||||
|
const auto expected_out_high = std::make_shared<ngraph::opset4::Multiply>(out_high, mul_value);
|
||||||
|
|
||||||
|
const auto expected_fq = std::make_shared<ngraph::opset4::FakeQuantize>(
|
||||||
|
data, in_low, in_high, expected_out_low, expected_out_high, 42);
|
||||||
|
|
||||||
|
const auto expected_function = std::make_shared<ngraph::Function>(
|
||||||
|
ngraph::OutputVector{expected_fq}, ngraph::ParameterVector{out_high});
|
||||||
|
|
||||||
|
ngraph::pass::Manager manager;
|
||||||
|
manager.register_pass<ngraph::pass::InitNodeInfo>();
|
||||||
|
manager.register_pass<ngraph::pass::FakeQuantizeMulFusion>();
|
||||||
|
|
||||||
|
manager.run_passes(function);
|
||||||
|
ASSERT_NO_THROW(check_rt_info(function));
|
||||||
|
|
||||||
|
const auto res = compare_functions(function, expected_function);
|
||||||
|
ASSERT_TRUE(res.first) << res.second;
|
||||||
|
}
|
||||||
|
|
||||||
|
} // namespace
|
||||||
|
|
||||||
|
} // namespace LayerTestsDefinitions
|
Loading…
Reference in New Issue
Block a user