Backport of FQ+Mul transform to master (#2214)
* Backport of FQ+Mul transform to master * Accept any type of input to FQ in the transformation * Test the fusion when all FQ inputs are non-const * Fusion test when only one output limit is const * Test passing the output of FQ to second input of Mul
This commit is contained in:
parent
c13ec24e1e
commit
dda6d9136b
@ -0,0 +1,32 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <transformations_visibility.hpp>
|
||||
|
||||
#include <ngraph/pass/graph_rewrite.hpp>
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
||||
class TRANSFORMATIONS_API FakeQuantizeMulFusion;
|
||||
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
/**
|
||||
* @ingroup ie_transformation_common_api
|
||||
* @brief This transformation looks for a FQ + Mul pair in the graph and moves
|
||||
* the Mul operation above the FQ node. The last two inputs of FQ are multiplied
|
||||
* by the value that was originally below the FQ node.
|
||||
*/
|
||||
|
||||
class ngraph::pass::FakeQuantizeMulFusion : public ngraph::pass::MatcherPass {
|
||||
public:
|
||||
FakeQuantizeMulFusion();
|
||||
};
|
@ -0,0 +1,108 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "transformations/common_optimizations/fq_mul_fusion.hpp"
|
||||
#include "transformations/utils/utils.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include <ngraph/rt_info.hpp>
|
||||
|
||||
namespace {
|
||||
std::pair<ngraph::Output<ngraph::Node>, ngraph::Output<ngraph::Node>>
|
||||
get_adjusted_output_range(ngraph::Output<ngraph::Node> out_low,
|
||||
ngraph::Output<ngraph::Node> out_high,
|
||||
ngraph::Output<ngraph::Node> multiplier) {
|
||||
const auto mul_out_low = std::make_shared<ngraph::opset4::Multiply>(out_low, multiplier);
|
||||
const auto mul_out_high = std::make_shared<ngraph::opset4::Multiply>(out_high, multiplier);
|
||||
copy_runtime_info({out_low.get_node_shared_ptr(), multiplier.get_node_shared_ptr()},
|
||||
mul_out_low);
|
||||
copy_runtime_info({out_high.get_node_shared_ptr(), multiplier.get_node_shared_ptr()},
|
||||
mul_out_high);
|
||||
|
||||
ngraph::OutputVector new_out_low(1), new_out_high(1);
|
||||
|
||||
if (!mul_out_low->constant_fold(new_out_low, {out_low, multiplier})) {
|
||||
new_out_low[0] = mul_out_low;
|
||||
}
|
||||
|
||||
if (!mul_out_high->constant_fold(new_out_high, {out_high, multiplier})) {
|
||||
new_out_high[0] = mul_out_high;
|
||||
}
|
||||
|
||||
return {new_out_low[0], new_out_high[0]};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
// This transformation multiplies the "output_low" and "output_high" inputs of the FQ operation
|
||||
// by the constant value that before transformation is used to multiply the output of FQ.
|
||||
// Both output_low and output_high are multiplied by the value represented as C (a constant) below.
|
||||
// In case any of the FQ inputs (out_L, out_H) is constant, it gets constant folded with C.
|
||||
//
|
||||
// data in_L in_H out_L out_H
|
||||
// | | | | |
|
||||
// | | | | | data in_L in_H out_L * C out_H * C
|
||||
// v v v v v | | | | |
|
||||
// +-------------------------+ | | | | |
|
||||
// | FakeQuantize | v v v v v
|
||||
// +-------------------------+ +-----------------------------------+
|
||||
// | =====> | FakeQuantize |
|
||||
// v +-----------------------------------+
|
||||
// +----------+ |
|
||||
// | Multiply | <--- C v
|
||||
// +----+-----+
|
||||
// |
|
||||
// v
|
||||
//
|
||||
|
||||
// Registers a matcher rooted at a Multiply whose inputs are a FakeQuantize and a
// Constant. On a match the FakeQuantize is cloned with rescaled output limits and
// the clone takes the place of the Multiply.
ngraph::pass::FakeQuantizeMulFusion::FakeQuantizeMulFusion() {
    // The two FQ inputs that get rescaled have dedicated pattern nodes so that the
    // callback can look up the values bound to them.
    const auto out_low_pattern = ngraph::pattern::any_input();
    const auto out_high_pattern = ngraph::pattern::any_input();

    // FakeQuantize with arbitrary inputs and exactly one consumer.
    const auto fq_pattern = ngraph::pattern::wrap_type<opset4::FakeQuantize>(
        {ngraph::pattern::any_input(),
         ngraph::pattern::any_input(),
         ngraph::pattern::any_input(),
         out_low_pattern,
         out_high_pattern},
        pattern::consumers_count(1));

    // Multiply of that FakeQuantize by a Constant, again with exactly one consumer.
    const auto const_pattern = ngraph::pattern::wrap_type<opset4::Constant>();
    const auto mul_pattern = ngraph::pattern::wrap_type<opset4::Multiply>(
        {fq_pattern, const_pattern}, pattern::consumers_count(1));

    ngraph::matcher_pass_callback callback = [=](pattern::Matcher &m) {
        const auto& matched = m.get_pattern_value_map();

        const auto old_fq = matched.at(fq_pattern).get_node_shared_ptr();

        // output_low / output_high multiplied by the constant (folded when possible).
        const auto scaled_limits = get_adjusted_output_range(matched.at(out_low_pattern),
                                                             matched.at(out_high_pattern),
                                                             matched.at(const_pattern));

        // Same FakeQuantize, only the two output limits are swapped for the scaled ones.
        const auto fused_fq = old_fq->clone_with_new_inputs({old_fq->input_value(0),
                                                             old_fq->input_value(1),
                                                             old_fq->input_value(2),
                                                             scaled_limits.first,
                                                             scaled_limits.second});

        const auto old_mul = matched.at(mul_pattern).get_node_shared_ptr();
        replace_node(old_mul, fused_fq);

        fused_fq->set_friendly_name(old_fq->get_friendly_name());
        copy_runtime_info({old_fq, old_mul}, fused_fq);

        return true;
    };

    const auto matcher =
        std::make_shared<ngraph::pattern::Matcher>(mul_pattern, "FakeQuantizeMulFusion");
    this->register_matcher(matcher, callback);
}
|
@ -51,6 +51,7 @@
|
||||
#include <transformations/hswish_decomposition.hpp>
|
||||
#include <transformations/reduce_l1_decomposition.hpp>
|
||||
#include <transformations/reduce_l2_decomposition.hpp>
|
||||
#include <transformations/common_optimizations/fq_mul_fusion.hpp>
|
||||
|
||||
#include <ngraph/pass/constant_folding.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
@ -111,6 +112,9 @@ bool ngraph::pass::ConvertOpSet1ToLegacy::run_on_function(std::shared_ptr<ngraph
|
||||
manager.register_pass<ngraph::pass::GroupConvolutionBackpropDataMultiplyFusion>();
|
||||
manager.register_pass<ngraph::pass::ConstantFolding>();
|
||||
|
||||
// Multiply the third and fourth input instead of the output of FQ with all const inputs
|
||||
manager.register_pass<ngraph::pass::FakeQuantizeMulFusion>();
|
||||
|
||||
// Convolution/Deconvolution/FullyConnected fusions
|
||||
auto convert_convolutions = manager.register_pass<ngraph::pass::GraphRewrite>();
|
||||
convert_convolutions->add_matcher<ngraph::pass::ConvertConvolution>();
|
||||
|
@ -0,0 +1,362 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
|
||||
#include <ie_core.hpp>
|
||||
|
||||
#include "common_test_utils/common_utils.hpp"
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "common_test_utils/test_common.hpp"
|
||||
#include "functional_test_utils/plugin_cache.hpp"
|
||||
|
||||
#include <ngraph/function.hpp>
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
#include <ngraph/pass/manager.hpp>
|
||||
#include <transformations/common_optimizations/fq_mul_fusion.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
using FQMulFusionParams =
|
||||
std::tuple<ngraph::Shape, // FQ data shape
|
||||
ngraph::Shape, // in_* shape
|
||||
ngraph::Shape, // out_* shape
|
||||
ngraph::Shape, // Mul constant shape
|
||||
ngraph::Shape>; // Expected shape of the new out_* constants
|
||||
|
||||
// Fixture that builds two all-constant graphs from the test parameters:
//   m_function          - FakeQuantize followed by Multiply with a constant,
//   m_expected_function - a single FakeQuantize whose out_low / out_high constants
//                         have the expected post-fusion shape and values.
class FQMulFusion : public testing::WithParamInterface<FQMulFusionParams>,
                    public CommonTestUtils::TestsCommon {
public:
    void SetUp() override {
        const auto& params = this->GetParam();
        const ngraph::Shape& data_shape = std::get<0>(params);
        const ngraph::Shape& in_shape = std::get<1>(params);
        const ngraph::Shape& out_shape = std::get<2>(params);
        const ngraph::Shape& mul_const_shape = std::get<3>(params);
        const ngraph::Shape& expected_out_shape = std::get<4>(params);

        // Every node in both graphs is an f32 constant filled with a single value.
        const auto make_f32 = [](const ngraph::Shape& shape, float value) {
            return ngraph::opset4::Constant::create(
                ngraph::element::Type_t::f32, shape, {value});
        };

        // Graph under test: FQ -> Multiply(constant).
        const auto data = make_f32(data_shape, 0.0f);
        const auto in_low = make_f32(in_shape, -0.5f);
        const auto in_high = make_f32(in_shape, 0.5f);
        const auto out_low = make_f32(out_shape, 0.0f);
        const auto out_high = make_f32(out_shape, 100.0f);
        const auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(
            data, in_low, in_high, out_low, out_high, 255);

        const auto mul_value = make_f32(mul_const_shape, 3.14f);
        const auto mul = std::make_shared<ngraph::opset4::Multiply>(fq, mul_value);

        m_function = std::make_shared<ngraph::Function>(
            ngraph::OutputVector{mul}, ngraph::ParameterVector{}, "FQMulFusion");

        // Reference graph: the output limits already carry the multiplier
        // (0 * 3.14 and 100 * 3.14).
        const auto expected_data = make_f32(data_shape, 0.0f);
        const auto expected_in_low = make_f32(in_shape, -0.5f);
        const auto expected_in_high = make_f32(in_shape, 0.5f);
        const auto expected_out_low = make_f32(expected_out_shape, 0.0f);
        const auto expected_out_high = make_f32(expected_out_shape, 314.0f);

        const auto expected_fq = std::make_shared<ngraph::opset4::FakeQuantize>(
            expected_data, expected_in_low, expected_in_high,
            expected_out_low, expected_out_high, 255);

        m_expected_function = std::make_shared<ngraph::Function>(
            ngraph::OutputVector{expected_fq}, ngraph::ParameterVector{}, "FQMulFusion_expected");
    }

    std::shared_ptr<ngraph::Function> m_function;
    std::shared_ptr<ngraph::Function> m_expected_function;
};
|
||||
|
||||
// Runs InitNodeInfo + FakeQuantizeMulFusion on the fixture's graph and checks
// that the result carries runtime info and matches the expected single-FQ graph.
TEST_P(FQMulFusion, ExpectFusion) {
    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::InitNodeInfo>();
    manager.register_pass<ngraph::pass::FakeQuantizeMulFusion>();

    manager.run_passes(m_function);
    ASSERT_NO_THROW(check_rt_info(m_function));

    const auto res = compare_functions(m_function, m_expected_function);
    // res.second holds the message printed when the comparison fails.
    ASSERT_TRUE(res.first) << res.second;
}
|
||||
|
||||
namespace {
|
||||
INSTANTIATE_TEST_CASE_P(ScalarFQParams_C6_4D_channel_0, FQMulFusion,
|
||||
::testing::Combine(::testing::Values(ngraph::Shape{64, 3, 7, 7}),
|
||||
::testing::Values(ngraph::Shape{}),
|
||||
::testing::Values(ngraph::Shape{}),
|
||||
::testing::Values(ngraph::Shape{64, 1, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{64, 1, 1, 1})));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(ScalarFQParams_C6_4D_channel_1, FQMulFusion,
|
||||
::testing::Combine(::testing::Values(ngraph::Shape{64, 3, 7, 7}),
|
||||
::testing::Values(ngraph::Shape{}),
|
||||
::testing::Values(ngraph::Shape{}),
|
||||
::testing::Values(ngraph::Shape{1, 3, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 3, 1, 1})));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(ScalarFQParams_C6_scalar, FQMulFusion,
|
||||
::testing::Combine(::testing::Values(ngraph::Shape{64, 3, 7, 7}),
|
||||
::testing::Values(ngraph::Shape{}),
|
||||
::testing::Values(ngraph::Shape{}),
|
||||
::testing::Values(ngraph::Shape{}),
|
||||
::testing::Values(ngraph::Shape{})));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(FQOutputs1D_C6_scalar, FQMulFusion,
|
||||
::testing::Combine(::testing::Values(ngraph::Shape{64, 3, 7, 7}),
|
||||
::testing::Values(ngraph::Shape{}),
|
||||
::testing::Values(ngraph::Shape{1}),
|
||||
::testing::Values(ngraph::Shape{}),
|
||||
::testing::Values(ngraph::Shape{1})));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(FQOutputs_NHWC_C6_scalar, FQMulFusion,
|
||||
::testing::Combine(::testing::Values(ngraph::Shape{1, 7, 7, 3}),
|
||||
::testing::Values(ngraph::Shape{}),
|
||||
::testing::Values(ngraph::Shape{1, 1, 1, 3}),
|
||||
::testing::Values(ngraph::Shape{}),
|
||||
::testing::Values(ngraph::Shape{1, 1, 1, 3})));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(FQOutputs_NCHW_C6_scalar, FQMulFusion,
|
||||
::testing::Combine(::testing::Values(ngraph::Shape{1, 3, 7, 7}),
|
||||
::testing::Values(ngraph::Shape{}),
|
||||
::testing::Values(ngraph::Shape{1, 3, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{}),
|
||||
::testing::Values(ngraph::Shape{1, 3, 1, 1})));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(FQInputs_4D_with_channel_dimension, FQMulFusion,
|
||||
::testing::Combine(::testing::Values(ngraph::Shape{1, 64, 3, 3}),
|
||||
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 64, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 64, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 64, 1, 1})));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(FQInputs_4D_per__multiplier_with_channel, FQMulFusion,
|
||||
::testing::Combine(::testing::Values(ngraph::Shape{1, 64, 3, 3}),
|
||||
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 64, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 64, 1, 1})));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(FQInputs_4D_with_channel__multiplier_4D_per_tensor, FQMulFusion,
|
||||
::testing::Combine(::testing::Values(ngraph::Shape{1, 64, 3, 3}),
|
||||
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 64, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 64, 1, 1})));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(FQInputs_4D__multiplier_channel_3rd_dim, FQMulFusion,
|
||||
::testing::Combine(::testing::Values(ngraph::Shape{1, 64, 3, 3}),
|
||||
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 64, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 1, 3, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 64, 3, 1})));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(FQOutputs_1D__multiplier_3D, FQMulFusion,
|
||||
::testing::Combine(::testing::Values(ngraph::Shape{1, 64, 3, 3}),
|
||||
::testing::Values(ngraph::Shape{1, 64, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1}),
|
||||
::testing::Values(ngraph::Shape{1, 3, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 3, 1})));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(FQ_all_ones__multiplier_4D_with_channel, FQMulFusion,
|
||||
::testing::Combine(::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 64, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 64, 1, 1})));
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(FQInOUt_ones__multiplier_4D_with_channel, FQMulFusion,
|
||||
::testing::Combine(::testing::Values(ngraph::Shape{1, 64, 3, 3}),
|
||||
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 1, 1, 1}),
|
||||
::testing::Values(ngraph::Shape{1, 64, 3, 3}),
|
||||
::testing::Values(ngraph::Shape{1, 64, 3, 3})));
|
||||
|
||||
// All five FakeQuantize inputs are Parameters, so nothing can be constant folded:
// the expected graph keeps a Multiply on both out_low and out_high.
TEST(FQMulFusion_NonConstInputs, AllInputsNonConst) {
    const auto scalar_param = [] {
        return std::make_shared<ngraph::opset4::Parameter>(
            ngraph::element::Type_t::f32, ngraph::Shape{});
    };

    const auto data = std::make_shared<ngraph::opset4::Parameter>(
        ngraph::element::Type_t::f32, ngraph::Shape{1, 3, 224, 224});
    const auto in_low = scalar_param();
    const auto in_high = scalar_param();
    const auto out_low = scalar_param();
    const auto out_high = scalar_param();
    const auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(
        data, in_low, in_high, out_low, out_high, 42);

    const auto mul_value = ngraph::opset4::Constant::create(
        ngraph::element::Type_t::f32, ngraph::Shape{}, {3.14f});
    const auto mul = std::make_shared<ngraph::opset4::Multiply>(fq, mul_value);

    // Both graphs are built on the same set of parameters.
    const ngraph::ParameterVector inputs{data, in_low, in_high, out_low, out_high};

    auto function = std::make_shared<ngraph::Function>(ngraph::OutputVector{mul}, inputs);

    const auto expected_out_low =
        std::make_shared<ngraph::opset4::Multiply>(out_low, mul_value);
    const auto expected_out_high =
        std::make_shared<ngraph::opset4::Multiply>(out_high, mul_value);

    const auto expected_fq = std::make_shared<ngraph::opset4::FakeQuantize>(
        data, in_low, in_high, expected_out_low, expected_out_high, 42);

    const auto expected_function =
        std::make_shared<ngraph::Function>(ngraph::OutputVector{expected_fq}, inputs);

    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::InitNodeInfo>();
    manager.register_pass<ngraph::pass::FakeQuantizeMulFusion>();

    manager.run_passes(function);
    ASSERT_NO_THROW(check_rt_info(function));

    const auto res = compare_functions(function, expected_function);
    ASSERT_TRUE(res.first) << res.second;
}
|
||||
|
||||
// Only out_high is a Constant; every other FakeQuantize input is a Parameter.
// Expected: out_high is folded with the multiplier, out_low keeps a Multiply.
TEST(FQMulFusion_NonConstInputs, FQ_out_high_const) {
    const auto scalar_param = [] {
        return std::make_shared<ngraph::opset4::Parameter>(
            ngraph::element::Type_t::f32, ngraph::Shape{});
    };
    const auto scalar_const = [](float value) {
        return ngraph::opset4::Constant::create(
            ngraph::element::Type_t::f32, ngraph::Shape{}, {value});
    };

    const auto data = std::make_shared<ngraph::opset4::Parameter>(
        ngraph::element::Type_t::f32, ngraph::Shape{1, 3, 224, 224});
    const auto in_low = scalar_param();
    const auto in_high = scalar_param();
    const auto out_low = scalar_param();
    const auto out_high = scalar_const(100.0f);
    const auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(
        data, in_low, in_high, out_low, out_high, 42);

    const auto mul_value = scalar_const(3.14f);
    const auto mul = std::make_shared<ngraph::opset4::Multiply>(fq, mul_value);

    // Both graphs are built on the same set of parameters.
    const ngraph::ParameterVector inputs{data, in_low, in_high, out_low};

    auto function = std::make_shared<ngraph::Function>(ngraph::OutputVector{mul}, inputs);

    const auto expected_out_low =
        std::make_shared<ngraph::opset4::Multiply>(out_low, mul_value);
    // this constant should be created by constant folding of the last FQ input
    const auto expected_out_high = scalar_const(314.0f);

    const auto expected_fq = std::make_shared<ngraph::opset4::FakeQuantize>(
        data, in_low, in_high, expected_out_low, expected_out_high, 42);

    const auto expected_function =
        std::make_shared<ngraph::Function>(ngraph::OutputVector{expected_fq}, inputs);

    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::InitNodeInfo>();
    manager.register_pass<ngraph::pass::FakeQuantizeMulFusion>();

    manager.run_passes(function);
    ASSERT_NO_THROW(check_rt_info(function));

    const auto res = compare_functions(function, expected_function);
    ASSERT_TRUE(res.first) << res.second;
}
|
||||
|
||||
// All-constant graph in which the FakeQuantize feeds the SECOND input of the
// Multiply; the fusion is still expected to produce a single FakeQuantize.
TEST(FQMulFusion_FQ_Mul_inputs, FQ_out_to_mul_input_2) {
    const auto scalar_const = [](float value) {
        return ngraph::opset4::Constant::create(
            ngraph::element::Type_t::f32, ngraph::Shape{}, {value});
    };

    const auto data = ngraph::opset4::Constant::create(
        ngraph::element::Type_t::f32, ngraph::Shape{1, 3, 224, 224}, {0.0f});
    const auto in_low = scalar_const(-0.5f);
    const auto in_high = scalar_const(0.5f);
    const auto out_low = scalar_const(0.0f);
    const auto out_high = scalar_const(100.0f);
    const auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(
        data, in_low, in_high, out_low, out_high, 42);

    const auto mul_value = scalar_const(3.14f);
    // here the FQ's output is passed to the second input of the Mul operation
    const auto mul = std::make_shared<ngraph::opset4::Multiply>(mul_value, fq);

    auto function = std::make_shared<ngraph::Function>(
        ngraph::OutputVector{mul}, ngraph::ParameterVector{});

    const auto expected_out_low = scalar_const(0.0f);
    const auto expected_out_high = scalar_const(314.0f);

    const auto expected_fq = std::make_shared<ngraph::opset4::FakeQuantize>(
        data, in_low, in_high, expected_out_low, expected_out_high, 42);

    const auto expected_function = std::make_shared<ngraph::Function>(
        ngraph::OutputVector{expected_fq}, ngraph::ParameterVector{});

    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::InitNodeInfo>();
    manager.register_pass<ngraph::pass::FakeQuantizeMulFusion>();

    manager.run_passes(function);
    ASSERT_NO_THROW(check_rt_info(function));

    const auto res = compare_functions(function, expected_function);
    ASSERT_TRUE(res.first) << res.second;
}
|
||||
|
||||
// FakeQuantize feeds the second input of the Multiply and out_high is a Parameter:
// out_low is expected as a constant, out_high as a Multiply that stays in the graph.
TEST(FQMulFusion_FQ_Mul_inputs, FQ_out_to_mul_input_2_param) {
    const auto scalar_const = [](float value) {
        return ngraph::opset4::Constant::create(
            ngraph::element::Type_t::f32, ngraph::Shape{}, {value});
    };

    const auto data = ngraph::opset4::Constant::create(
        ngraph::element::Type_t::f32, ngraph::Shape{1, 3, 224, 224}, {0.0f});
    const auto in_low = scalar_const(-0.5f);
    const auto in_high = scalar_const(0.5f);
    const auto out_low = scalar_const(0.0f);
    // out_high is a parameter, which means it should not be constant folded
    const auto out_high = std::make_shared<ngraph::opset4::Parameter>(
        ngraph::element::Type_t::f32, ngraph::Shape{});
    const auto fq = std::make_shared<ngraph::opset4::FakeQuantize>(
        data, in_low, in_high, out_low, out_high, 42);

    const auto mul_value = scalar_const(3.14f);
    // and here the output of FQ is passed as the second input of Mul
    const auto mul = std::make_shared<ngraph::opset4::Multiply>(mul_value, fq);

    // Both graphs are built on the same single parameter.
    const ngraph::ParameterVector inputs{out_high};

    auto function = std::make_shared<ngraph::Function>(ngraph::OutputVector{mul}, inputs);

    const auto expected_out_low = scalar_const(0.0f);
    const auto expected_out_high =
        std::make_shared<ngraph::opset4::Multiply>(out_high, mul_value);

    const auto expected_fq = std::make_shared<ngraph::opset4::FakeQuantize>(
        data, in_low, in_high, expected_out_low, expected_out_high, 42);

    const auto expected_function =
        std::make_shared<ngraph::Function>(ngraph::OutputVector{expected_fq}, inputs);

    ngraph::pass::Manager manager;
    manager.register_pass<ngraph::pass::InitNodeInfo>();
    manager.register_pass<ngraph::pass::FakeQuantizeMulFusion>();

    manager.run_passes(function);
    ASSERT_NO_THROW(check_rt_info(function));

    const auto res = compare_functions(function, expected_function);
    ASSERT_TRUE(res.first) << res.second;
}
|
||||
|
||||
} // namespace
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
Loading…
Reference in New Issue
Block a user