[LPT] FQ Decomposition improvement (#14203)
* [LPT] FQ Decomposition modified to create FQ whose constants have no other consumers * [LPT] Added subgraph test * CPUTestsBase: added the ability to check only fusing results * [CPU] Added subgraph test * LPT review comments applied * CPUTestsUtils: added special string 'anytype' to make the selectedType check be ignored
This commit is contained in:
committed by
GitHub
parent
2a1a2532b2
commit
b11adcdde4
@@ -257,6 +257,8 @@ public:
|
||||
float& updatedOutputLowValue,
|
||||
float& updatedOutputHighValue);
|
||||
|
||||
static ov::Output<ov::Node> getSingleConsumerConstant(const ov::Output<ov::Node>& output);
|
||||
|
||||
private:
|
||||
static std::shared_ptr<Node> foldFakeQuantize(
|
||||
const std::shared_ptr<opset1::FakeQuantize>& fq,
|
||||
|
||||
@@ -1060,8 +1060,8 @@ std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> NetworkHelper::decompos
|
||||
std::shared_ptr<ngraph::Node> newFQ = fold_fake_quantize(
|
||||
std::make_shared<op::TypeRelaxed<opset1::FakeQuantize>>(
|
||||
fq->input_value(0),
|
||||
fq->input_value(1),
|
||||
fq->input_value(2),
|
||||
getSingleConsumerConstant(fq->input_value(1)),
|
||||
getSingleConsumerConstant(fq->input_value(2)),
|
||||
newMin->output(0),
|
||||
newMax->output(0),
|
||||
fq->get_levels(),
|
||||
@@ -1124,15 +1124,17 @@ std::shared_ptr<opset1::FakeQuantize> NetworkHelper::updateFakeQuantize(
|
||||
float min,
|
||||
float max,
|
||||
const bool replace) {
|
||||
auto newMin = std::make_shared<opset1::Constant>(fq->get_output_element_type(0), Shape{}, min);
|
||||
auto newMax = std::make_shared<opset1::Constant>(fq->get_output_element_type(0), Shape{}, max);
|
||||
auto newInMin = getSingleConsumerConstant(fq->input_value(1));
|
||||
auto newInMax = getSingleConsumerConstant(fq->input_value(2));
|
||||
auto newOutMin = std::make_shared<opset1::Constant>(fq->get_output_element_type(0), Shape{}, min);
|
||||
auto newOutMax = std::make_shared<opset1::Constant>(fq->get_output_element_type(0), Shape{}, max);
|
||||
|
||||
std::shared_ptr<opset1::FakeQuantize> newFQ = std::make_shared<ngraph::op::TypeRelaxed<opset1::FakeQuantize>>(
|
||||
fq->input_value(0),
|
||||
fq->input_value(1),
|
||||
fq->input_value(2),
|
||||
newMin->output(0),
|
||||
newMax->output(0),
|
||||
newInMin,
|
||||
newInMax,
|
||||
newOutMin->output(0),
|
||||
newOutMax->output(0),
|
||||
fq->get_levels(),
|
||||
fq->get_auto_broadcast());
|
||||
|
||||
@@ -1999,6 +2001,15 @@ void NetworkHelper::insertDequantizationAfter(
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
ov::Output<ov::Node> NetworkHelper::getSingleConsumerConstant(const ov::Output<ov::Node>& output) {
    // Returns the given constant output as-is when it already feeds exactly one
    // consumer; otherwise returns a fresh clone of the constant so the caller
    // obtains a constant that is not shared with any other node.
    // Throws when the output does not belong to an opset1::Constant.
    const auto constant = output.get_node();
    if (!ngraph::is_type<opset1::Constant>(constant)) {
        THROW_IE_LPT_EXCEPTION(*constant) << "getSingleConsumerConstant Expected Constant node type";
    }

    if (output.get_target_inputs().size() == 1) {
        // Already a single consumer: safe to reuse the existing constant.
        return output;
    }
    // Shared constant: duplicate it so each consumer can own its copy.
    return constant->clone_with_new_inputs(constant->input_values())->output(0);
}
|
||||
} // namespace low_precision
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
||||
|
||||
@@ -35,7 +35,7 @@ ov_add_test_target(
|
||||
commonTestUtils
|
||||
lptNgraphFunctions
|
||||
gmock
|
||||
INCLUDES
|
||||
INCLUDES ${CMAKE_CURRENT_SOURCE_DIR}
|
||||
LABELS
|
||||
LP_TRANSFORMATIONS
|
||||
)
|
||||
@@ -0,0 +1,97 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "layer_transformation.hpp"
|
||||
#include "simple_low_precision_transformer.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <low_precision/fake_quantize_decomposition.hpp>
|
||||
#include <low_precision/low_precision.hpp>
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "lpt_ngraph_functions/common/builders.hpp"
|
||||
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
|
||||
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
using namespace ngraph::builder::subgraph;
|
||||
|
||||
|
||||
// Checks FakeQuantizeDecomposition when several FakeQuantize nodes share the
// same interval constants: after the transformation each FQ must own its own
// (non-shared) constants. The bool parameter toggles IntervalsAlignment /
// QuantizationAlignment attributes on the quantize nodes.
class FQDecompositionWithSharedConstants : public LayerTransformation, public WithParamInterface<bool> {
public:
    void SetUp() override {
        const bool withIntervalsAlignment = GetParam();
        const auto dataShape = ngraph::Shape{1, 3, 40, 40};
        const auto precision = ngraph::element::f32;

        // Actual graph: two FakeQuantize nodes sharing the very same interval
        // constants, followed by Relu.
        {
            auto parameter = std::make_shared<opset1::Parameter>(precision, dataShape);
            const auto scalar = [&](const float value) {
                return opset1::Constant::create(precision, {}, {value});
            };
            auto inLow = scalar(0.f);
            auto inHigh = scalar(25.5f);
            auto outLow = scalar(0.f);
            auto outHigh = scalar(25.5f);
            auto fqBefore = std::make_shared<opset1::FakeQuantize>(parameter, inLow, inHigh, outLow, outHigh, 256);
            auto fqAfter = std::make_shared<opset1::FakeQuantize>(fqBefore, inLow, inHigh, outLow, outHigh, 256);
            auto relu = std::make_shared<opset1::Relu>(fqAfter);
            if (withIntervalsAlignment) {
                addAttributes({fqBefore, fqAfter},
                              {IntervalsAlignmentAttribute(IntervalsAlignmentSharedValue::Interval{0.f, 2.55f}, 256ul)});
                addAttributes({fqAfter, relu}, {QuantizationAlignmentAttribute(true)});
            }
            actualFunction = std::make_shared<Function>(ResultVector{std::make_shared<opset1::Result>(relu)},
                                                        ParameterVector{parameter},
                                                        "FakeQuantizeFunction");
        }

        SimpleLowPrecisionTransformer transform;
        transform.add<pass::low_precision::FakeQuantizeDecompositionTransformation, opset1::FakeQuantize>(
            LayerTransformation::createParamsU8I8());
        transform.transform(actualFunction);

        // Reference graph: each FakeQuantize decomposed into quantize +
        // dequantize, with independent constants per node.
        {
            auto parameter = std::make_shared<opset1::Parameter>(precision, dataShape);
            auto fqOnData = FakeQuantizeOnData{256ul, Shape({}), {0.f}, {25.5f}, {0.f}, {255.f}, ngraph::element::u8};
            auto dequantization = DequantizationOperations{{element::f32}, {}, {0.1f}};
            auto fqBefore = makeFakeQuantizeTypeRelaxed(parameter, precision, fqOnData);
            auto dqBefore = makeDequantization(fqBefore, dequantization);
            auto fqAfter = makeFakeQuantizeTypeRelaxed(dqBefore, precision, fqOnData);
            auto dqAfter = makeDequantization(fqAfter, dequantization);
            auto relu = std::make_shared<opset1::Relu>(dqAfter);
            referenceFunction = std::make_shared<Function>(ResultVector{std::make_shared<opset1::Result>(relu)},
                                                           ParameterVector{parameter},
                                                           "FakeQuantizeFunction");
        }
    }

    // Human-readable test name based on whether alignment attributes are added.
    static std::string getTestCaseName(testing::TestParamInfo<bool> obj) {
        return obj.param ? "with_IntervalsAlignment" : "without_IntervalsAlignment";
    }
};
|
||||
|
||||
TEST_P(FQDecompositionWithSharedConstants, FQDecompositionWithSharedConstants) {
    actualFunction->validate_nodes_and_infer_types();

    // Compare the transformed graph against the reference graph: node kinds,
    // constant values and element precisions must all match.
    auto comparator = FunctionsComparator::no_default();
    comparator.enable(FunctionsComparator::CmpValues::NODES);
    comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES);
    comparator.enable(FunctionsComparator::CmpValues::PRECISIONS);
    auto res = comparator.compare(actualFunction, referenceFunction);
    ASSERT_TRUE(res.valid) << res.message;

    // additional check: FQ constants after transformation mustn't be shared
    // Iterate by const reference: copying each shared_ptr<Node> would cost an
    // atomic refcount increment/decrement per node for no benefit
    // (clang-tidy: performance-for-range-copy).
    for (const auto& n : actualFunction->get_ordered_ops()) {
        if (ov::is_type<opset1::Constant>(n))
            EXPECT_EQ(n->get_output_target_inputs(0).size(), 1);
    }
}
|
||||
namespace {
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
smoke_LPT,
|
||||
FQDecompositionWithSharedConstants,
|
||||
::testing::ValuesIn(std::vector<bool>{false, true}),
|
||||
FQDecompositionWithSharedConstants::getTestCaseName);
|
||||
} // namespace
|
||||
Reference in New Issue
Block a user