[LPT] FuseConvert transformation extension (#10558)
* [LPT] FuseConvert transformation extension * [LPT] Tests * [LPT] Cleanup & tests refactoring
This commit is contained in:
parent
d7ad1bd9cd
commit
5be402750a
@ -7,7 +7,6 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <ngraph/ngraph.hpp>
|
#include <ngraph/ngraph.hpp>
|
||||||
#include "layer_transformation.hpp"
|
#include "layer_transformation.hpp"
|
||||||
#include "low_precision/fuse_fake_quantize.hpp"
|
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
|
@ -7,7 +7,6 @@
|
|||||||
#include <memory>
|
#include <memory>
|
||||||
#include <ngraph/ngraph.hpp>
|
#include <ngraph/ngraph.hpp>
|
||||||
#include "layer_transformation.hpp"
|
#include "layer_transformation.hpp"
|
||||||
#include "low_precision/fuse_fake_quantize.hpp"
|
|
||||||
|
|
||||||
namespace ngraph {
|
namespace ngraph {
|
||||||
namespace pass {
|
namespace pass {
|
||||||
|
@ -1,30 +0,0 @@
|
|||||||
// Copyright (C) 2018-2022 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
|
|
||||||
#pragma once
|
|
||||||
|
|
||||||
#include <memory>
|
|
||||||
#include <ngraph/ngraph.hpp>
|
|
||||||
#include "low_precision/layer_transformation.hpp"
|
|
||||||
|
|
||||||
namespace ngraph {
|
|
||||||
namespace pass {
|
|
||||||
namespace low_precision {
|
|
||||||
|
|
||||||
class LP_TRANSFORMATIONS_API FuseFakeQuantizeTransformation : public LayerTransformation {
|
|
||||||
public:
|
|
||||||
NGRAPH_RTTI_DECLARATION;
|
|
||||||
FuseFakeQuantizeTransformation(const Params& params);
|
|
||||||
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
|
|
||||||
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
|
|
||||||
|
|
||||||
private:
|
|
||||||
std::shared_ptr<opset1::FakeQuantize> handle(
|
|
||||||
TransformationContext& context,
|
|
||||||
const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) const;
|
|
||||||
};
|
|
||||||
|
|
||||||
} // namespace low_precision
|
|
||||||
} // namespace pass
|
|
||||||
} // namespace ngraph
|
|
@ -175,7 +175,8 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto data = fq::getDataNode(eltwise);
|
// issue #79980
|
||||||
|
const auto data = eltwise->get_input_size() == 1ul ? eltwise->get_input_node_shared_ptr(0) : fq::getDataNode(eltwise);
|
||||||
const size_t outputIdx = NetworkHelper::getParentOutputIndex(data, eltwise);
|
const size_t outputIdx = NetworkHelper::getParentOutputIndex(data, eltwise);
|
||||||
|
|
||||||
const auto newFakeQuantize = ov::as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({
|
const auto newFakeQuantize = ov::as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({
|
||||||
|
@ -23,8 +23,14 @@ FuseConvertTransformation::FuseConvertTransformation(const Params& params) : Lay
|
|||||||
auto multiply = pattern::wrap_type<opset1::Multiply>({ pattern::wrap_type<opset1::Convert>(), pattern::wrap_type<opset1::Constant>() });
|
auto multiply = pattern::wrap_type<opset1::Multiply>({ pattern::wrap_type<opset1::Convert>(), pattern::wrap_type<opset1::Constant>() });
|
||||||
auto subtract = pattern::wrap_type<opset1::Subtract>({ pattern::wrap_type<opset1::Convert>(), pattern::wrap_type<opset1::Constant>() });
|
auto subtract = pattern::wrap_type<opset1::Subtract>({ pattern::wrap_type<opset1::Convert>(), pattern::wrap_type<opset1::Constant>() });
|
||||||
auto add = pattern::wrap_type<opset1::Add>({ pattern::wrap_type<opset1::Convert>(), pattern::wrap_type<opset1::Constant>() });
|
auto add = pattern::wrap_type<opset1::Add>({ pattern::wrap_type<opset1::Convert>(), pattern::wrap_type<opset1::Constant>() });
|
||||||
|
auto fakeQuantize = pattern::wrap_type<opset1::FakeQuantize>({
|
||||||
|
pattern::wrap_type<opset1::Convert>({pattern::wrap_type<opset1::Constant>()}),
|
||||||
|
pattern::any_input(),
|
||||||
|
pattern::any_input(),
|
||||||
|
pattern::any_input(),
|
||||||
|
pattern::any_input()});
|
||||||
auto matcher = std::make_shared<ngraph::pattern::Matcher>(
|
auto matcher = std::make_shared<ngraph::pattern::Matcher>(
|
||||||
std::make_shared<pattern::op::Or>(OutputVector{ multiply, subtract, add }),
|
std::make_shared<pattern::op::Or>(OutputVector{ multiply, subtract, add, fakeQuantize }),
|
||||||
"FuseConvertTransformation");
|
"FuseConvertTransformation");
|
||||||
|
|
||||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||||
|
@ -1,193 +0,0 @@
|
|||||||
// Copyright (C) 2018-2022 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
|
|
||||||
#include "low_precision/fuse_fake_quantize.hpp"
|
|
||||||
#include <memory>
|
|
||||||
#include <ngraph/ngraph.hpp>
|
|
||||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
|
||||||
#include "low_precision/common/ie_lpt_exception.hpp"
|
|
||||||
#include "low_precision/network_helper.hpp"
|
|
||||||
|
|
||||||
namespace ngraph {
|
|
||||||
namespace pass {
|
|
||||||
namespace low_precision {
|
|
||||||
|
|
||||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::low_precision::FuseFakeQuantizeTransformation, "FuseFakeQuantizeTransformation", 0);
|
|
||||||
|
|
||||||
FuseFakeQuantizeTransformation::FuseFakeQuantizeTransformation(const Params& params) : LayerTransformation(params) {
|
|
||||||
auto matcher = pattern::wrap_type<opset1::FakeQuantize>();
|
|
||||||
|
|
||||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
|
||||||
auto op = m.get_match_root();
|
|
||||||
if (transformation_callback(op)) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
return transform(*context, m);
|
|
||||||
};
|
|
||||||
|
|
||||||
auto m = std::make_shared<ngraph::pattern::Matcher>(matcher, "FuseFakeQuantizeTransformation");
|
|
||||||
this->register_matcher(m, callback);
|
|
||||||
}
|
|
||||||
|
|
||||||
bool FuseFakeQuantizeTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) {
|
|
||||||
auto fakeQuantize = ov::as_type_ptr<ngraph::opset1::FakeQuantize>(m.get_match_root());
|
|
||||||
if (!fakeQuantize)
|
|
||||||
return false;
|
|
||||||
|
|
||||||
do {
|
|
||||||
fakeQuantize = handle(context, fakeQuantize);
|
|
||||||
} while (fakeQuantize != nullptr);
|
|
||||||
return true;
|
|
||||||
}
|
|
||||||
|
|
||||||
namespace fuse_fq {
|
|
||||||
namespace {
|
|
||||||
|
|
||||||
std::shared_ptr<Node> updateShape(std::shared_ptr<Node> op, const PartialShape& targetPShape) {
|
|
||||||
assert(targetPShape.is_static());
|
|
||||||
assert(op->get_output_partial_shape(0).is_static());
|
|
||||||
const Shape targetShape = targetPShape.to_shape();
|
|
||||||
const Shape shape = op->get_output_shape(0);
|
|
||||||
|
|
||||||
if ((shape.size() < targetShape.size()) && (shape.size() > 1ul)) {
|
|
||||||
op = fold<opset1::Unsqueeze>(
|
|
||||||
op,
|
|
||||||
std::make_shared<opset1::Constant>(ngraph::element::i32, Shape{ 1 }, std::vector<size_t>({ 0ul })));
|
|
||||||
}
|
|
||||||
return op;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<Node> getDataNode(const std::shared_ptr<Node>& eltwise) {
|
|
||||||
if (!ov::is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(0))) {
|
|
||||||
return eltwise->get_input_node_shared_ptr(0);
|
|
||||||
}
|
|
||||||
|
|
||||||
if (!ov::is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(1))) {
|
|
||||||
return eltwise->get_input_node_shared_ptr(1);
|
|
||||||
}
|
|
||||||
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<opset1::Constant> getConstant(const std::shared_ptr<Node>& eltwise) {
|
|
||||||
if (eltwise->get_input_size() != 2) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
std::shared_ptr<opset1::Constant> constant = ov::as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(1));
|
|
||||||
if (constant != nullptr) {
|
|
||||||
return constant;
|
|
||||||
}
|
|
||||||
|
|
||||||
return ov::as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(0));
|
|
||||||
}
|
|
||||||
|
|
||||||
bool eltwiseWithConstant(const std::shared_ptr<Node>& eltwise) {
|
|
||||||
std::shared_ptr<opset1::Constant> constant = getConstant(eltwise);
|
|
||||||
if (constant == nullptr) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
Shape shape = constant->get_shape();
|
|
||||||
if ((!shape.empty()) && (shape_size(shape) != 1ul)) {
|
|
||||||
const auto eltwisePShape = eltwise->get_output_partial_shape(0);
|
|
||||||
if (eltwisePShape.rank().is_dynamic()) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
const size_t eltwiseOutRank = eltwisePShape.rank().get_length();
|
|
||||||
if ((eltwiseOutRank - shape.size()) > 1) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
if ((eltwiseOutRank - shape.size()) == 1ul) {
|
|
||||||
shape.insert(shape.begin(), 1ul);
|
|
||||||
}
|
|
||||||
|
|
||||||
for (size_t i = 2ul; i < shape.size(); ++i) {
|
|
||||||
if (shape[i] != 1ul) {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return getDataNode(eltwise) != nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace
|
|
||||||
} // namespace fuse_fq
|
|
||||||
|
|
||||||
std::shared_ptr<opset1::FakeQuantize> FuseFakeQuantizeTransformation::handle(
|
|
||||||
TransformationContext& context,
|
|
||||||
const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) const {
|
|
||||||
const std::shared_ptr<Node> eltwise = fakeQuantize->get_input_node_shared_ptr(0);
|
|
||||||
|
|
||||||
std::shared_ptr<Node> inputLowConst = fakeQuantize->get_input_node_shared_ptr(1);
|
|
||||||
std::shared_ptr<Node> inputHightConst = fakeQuantize->get_input_node_shared_ptr(2);
|
|
||||||
|
|
||||||
std::shared_ptr<opset1::Constant> constant = fuse_fq::getConstant(eltwise);
|
|
||||||
if (ov::is_type<opset1::Multiply>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
|
|
||||||
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
|
|
||||||
constant :
|
|
||||||
foldConvert(constant, eltwise->get_output_element_type(0));
|
|
||||||
|
|
||||||
inputLowConst = fuse_fq::updateShape(fold<opset1::Divide>(inputLowConst, value), fakeQuantize->get_output_partial_shape(0));
|
|
||||||
inputHightConst = fuse_fq::updateShape(fold<opset1::Divide>(inputHightConst, value), fakeQuantize->get_output_partial_shape(0));
|
|
||||||
} else if (ov::is_type<opset1::Divide>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
|
|
||||||
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
|
|
||||||
constant :
|
|
||||||
foldConvert(constant, eltwise->get_output_element_type(0));
|
|
||||||
|
|
||||||
inputLowConst = fuse_fq::updateShape(fold<opset1::Multiply>(inputLowConst, value), fakeQuantize->get_output_partial_shape(0));
|
|
||||||
inputHightConst = fuse_fq::updateShape(fold<opset1::Multiply>(inputHightConst, value), fakeQuantize->get_output_partial_shape(0));
|
|
||||||
} else if (ov::is_type<opset1::Subtract>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
|
|
||||||
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
|
|
||||||
constant :
|
|
||||||
foldConvert(constant, eltwise->get_output_element_type(0));
|
|
||||||
|
|
||||||
inputLowConst = fuse_fq::updateShape(fold<opset1::Add>(inputLowConst, value), fakeQuantize->get_output_partial_shape(0));
|
|
||||||
inputHightConst = fuse_fq::updateShape(fold<opset1::Add>(inputHightConst, value), fakeQuantize->get_output_partial_shape(0));
|
|
||||||
} else if (ov::is_type<opset1::Add>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
|
|
||||||
if (ov::is_type<opset1::Convolution>(fuse_fq::getDataNode(eltwise)) ||
|
|
||||||
ov::is_type<opset1::GroupConvolution>(fuse_fq::getDataNode(eltwise))) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
|
|
||||||
constant :
|
|
||||||
foldConvert(constant, eltwise->get_output_element_type(0));
|
|
||||||
|
|
||||||
inputLowConst = fuse_fq::updateShape(fold<opset1::Subtract>(inputLowConst, value), fakeQuantize->get_output_partial_shape(0));
|
|
||||||
inputHightConst = fuse_fq::updateShape(fold<opset1::Subtract>(inputHightConst, value), fakeQuantize->get_output_partial_shape(0));
|
|
||||||
} else if (ov::is_type<opset1::Convert>(eltwise)) {
|
|
||||||
// issue #40611
|
|
||||||
if ((eltwise->get_input_element_type(0) == element::i32) && (eltwise->get_output_element_type(0) == element::f32)) {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto data = fuse_fq::getDataNode(eltwise);
|
|
||||||
const size_t outputIdx = NetworkHelper::getParentOutputIndex(data, eltwise);
|
|
||||||
|
|
||||||
std::shared_ptr<opset1::FakeQuantize> newFakeQuantize = ov::as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({
|
|
||||||
data->output(outputIdx),
|
|
||||||
inputLowConst,
|
|
||||||
inputHightConst,
|
|
||||||
fakeQuantize->input_value(3),
|
|
||||||
fakeQuantize->input_value(4) }));
|
|
||||||
|
|
||||||
replace_node(fakeQuantize, newFakeQuantize);
|
|
||||||
NetworkHelper::copyInfo(fakeQuantize, newFakeQuantize);
|
|
||||||
return newFakeQuantize;
|
|
||||||
}
|
|
||||||
|
|
||||||
bool FuseFakeQuantizeTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
|
|
||||||
return false;
|
|
||||||
}
|
|
||||||
|
|
||||||
} // namespace low_precision
|
|
||||||
} // namespace pass
|
|
||||||
} // namespace ngraph
|
|
@ -74,7 +74,6 @@
|
|||||||
#include "low_precision/convert.hpp"
|
#include "low_precision/convert.hpp"
|
||||||
#include "low_precision/fold_fake_quantize.hpp"
|
#include "low_precision/fold_fake_quantize.hpp"
|
||||||
#include "low_precision/fuse_convert.hpp"
|
#include "low_precision/fuse_convert.hpp"
|
||||||
#include "low_precision/fuse_fake_quantize.hpp"
|
|
||||||
#include "low_precision/fuse_subtract_to_fake_quantize.hpp"
|
#include "low_precision/fuse_subtract_to_fake_quantize.hpp"
|
||||||
#include "low_precision/fuse_multiply_to_fake_quantize.hpp"
|
#include "low_precision/fuse_multiply_to_fake_quantize.hpp"
|
||||||
#include "low_precision/multiply_to_group_convolution.hpp"
|
#include "low_precision/multiply_to_group_convolution.hpp"
|
||||||
|
@ -30,12 +30,14 @@ public:
|
|||||||
public:
|
public:
|
||||||
ngraph::element::Type inputPrecision;
|
ngraph::element::Type inputPrecision;
|
||||||
ngraph::builder::subgraph::DequantizationOperations dequantization;
|
ngraph::builder::subgraph::DequantizationOperations dequantization;
|
||||||
|
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize;
|
||||||
};
|
};
|
||||||
|
|
||||||
class Expected {
|
class Expected {
|
||||||
public:
|
public:
|
||||||
ngraph::element::Type inputPrecision;
|
ngraph::element::Type inputPrecision;
|
||||||
ngraph::builder::subgraph::DequantizationOperations dequantization;
|
ngraph::builder::subgraph::DequantizationOperations dequantization;
|
||||||
|
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize;
|
||||||
};
|
};
|
||||||
|
|
||||||
bool constInput;
|
bool constInput;
|
||||||
@ -58,6 +60,7 @@ public:
|
|||||||
inputShape,
|
inputShape,
|
||||||
testValues.actual.inputPrecision,
|
testValues.actual.inputPrecision,
|
||||||
testValues.actual.dequantization,
|
testValues.actual.dequantization,
|
||||||
|
testValues.actual.fakeQuantize,
|
||||||
testValues.constInput);
|
testValues.constInput);
|
||||||
|
|
||||||
SimpleLowPrecisionTransformer transformer;
|
SimpleLowPrecisionTransformer transformer;
|
||||||
@ -68,6 +71,7 @@ public:
|
|||||||
inputShape,
|
inputShape,
|
||||||
testValues.expected.inputPrecision,
|
testValues.expected.inputPrecision,
|
||||||
testValues.expected.dequantization,
|
testValues.expected.dequantization,
|
||||||
|
testValues.expected.fakeQuantize,
|
||||||
testValues.constInput);
|
testValues.constInput);
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -77,9 +81,13 @@ public:
|
|||||||
|
|
||||||
std::ostringstream result;
|
std::ostringstream result;
|
||||||
result <<
|
result <<
|
||||||
inputShape << "_" <<
|
"IS_" << inputShape << "_" <<
|
||||||
testValues.actual.inputPrecision << "_" <<
|
"AIP_" << testValues.actual.inputPrecision << "_" <<
|
||||||
testValues.actual.dequantization << "_" <<
|
"ADEQ_" << testValues.actual.dequantization << "_" <<
|
||||||
|
"AFQ_" << testValues.actual.fakeQuantize << "_" <<
|
||||||
|
"EIP_" << testValues.expected.inputPrecision << "_" <<
|
||||||
|
"EDEQ_" << testValues.expected.dequantization << "_" <<
|
||||||
|
"EFQ_" << testValues.expected.fakeQuantize << "_" <<
|
||||||
testValues.constInput;
|
testValues.constInput;
|
||||||
return result.str();
|
return result.str();
|
||||||
}
|
}
|
||||||
@ -111,7 +119,8 @@ const std::vector<FuseConvertTransformationTestValues> testValues = {
|
|||||||
{ ngraph::element::f32 },
|
{ ngraph::element::f32 },
|
||||||
{1.f},
|
{1.f},
|
||||||
{0.45f}
|
{0.45f}
|
||||||
}
|
},
|
||||||
|
{}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ngraph::element::u8,
|
ngraph::element::u8,
|
||||||
@ -119,7 +128,8 @@ const std::vector<FuseConvertTransformationTestValues> testValues = {
|
|||||||
{},
|
{},
|
||||||
DequantizationOperations::Subtract({1.f}, ngraph::element::f32).setConstantPrecision(ngraph::element::f32),
|
DequantizationOperations::Subtract({1.f}, ngraph::element::f32).setConstantPrecision(ngraph::element::f32),
|
||||||
{0.45f}
|
{0.45f}
|
||||||
}
|
},
|
||||||
|
{}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
// fuse to multiply
|
// fuse to multiply
|
||||||
@ -132,7 +142,8 @@ const std::vector<FuseConvertTransformationTestValues> testValues = {
|
|||||||
{ ngraph::element::f32 },
|
{ ngraph::element::f32 },
|
||||||
{},
|
{},
|
||||||
{0.45f}
|
{0.45f}
|
||||||
}
|
},
|
||||||
|
{}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ngraph::element::u8,
|
ngraph::element::u8,
|
||||||
@ -140,7 +151,8 @@ const std::vector<FuseConvertTransformationTestValues> testValues = {
|
|||||||
{},
|
{},
|
||||||
{},
|
{},
|
||||||
DequantizationOperations::Multiply({0.45f}, ngraph::element::f32).setConstantPrecision(ngraph::element::f32)
|
DequantizationOperations::Multiply({0.45f}, ngraph::element::f32).setConstantPrecision(ngraph::element::f32)
|
||||||
}
|
},
|
||||||
|
{}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
// Convert with unexpected precision
|
// Convert with unexpected precision
|
||||||
@ -149,11 +161,13 @@ const std::vector<FuseConvertTransformationTestValues> testValues = {
|
|||||||
LayerTransformation::createParamsU8I8(),
|
LayerTransformation::createParamsU8I8(),
|
||||||
{
|
{
|
||||||
ngraph::element::f32,
|
ngraph::element::f32,
|
||||||
{{ ngraph::element::i32 }, {}, {3.f}}
|
{{ ngraph::element::i32 }, {}, {3.f}},
|
||||||
|
{}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ngraph::element::f32,
|
ngraph::element::f32,
|
||||||
{{ ngraph::element::i32 }, {}, {3.f}}
|
{{ ngraph::element::i32 }, {}, {3.f}},
|
||||||
|
{}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
@ -173,6 +187,27 @@ const std::vector<ngraph::PartialShape> inputShapes = {
|
|||||||
};
|
};
|
||||||
|
|
||||||
const std::vector<FuseConvertTransformationTestValues> testValuesWithConstant = {
|
const std::vector<FuseConvertTransformationTestValues> testValuesWithConstant = {
|
||||||
|
// Constant
|
||||||
|
// |
|
||||||
|
// Convert Const Const Const Const
|
||||||
|
// \ \ | / /
|
||||||
|
// \ \ | / /
|
||||||
|
// FakeQuantize
|
||||||
|
//
|
||||||
|
{
|
||||||
|
true,
|
||||||
|
LayerTransformation::createParamsU8I8(),
|
||||||
|
{
|
||||||
|
ngraph::element::u8,
|
||||||
|
{{ngraph::element::f32}, {}, {}},
|
||||||
|
{ 256, {}, {0.f}, {0.1f}, {0.f}, {0.1f}, ov::element::f32}
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ngraph::element::f32,
|
||||||
|
{},
|
||||||
|
{ 256, {}, {0.f}, {0.1f}, {0.f}, {0.1f}, ov::element::f32}
|
||||||
|
}
|
||||||
|
},
|
||||||
// fuse to const
|
// fuse to const
|
||||||
{
|
{
|
||||||
true,
|
true,
|
||||||
@ -183,7 +218,8 @@ const std::vector<FuseConvertTransformationTestValues> testValuesWithConstant =
|
|||||||
{ ngraph::element::f32 },
|
{ ngraph::element::f32 },
|
||||||
{1.f},
|
{1.f},
|
||||||
{0.45f}
|
{0.45f}
|
||||||
}
|
},
|
||||||
|
{}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
ngraph::element::f32,
|
ngraph::element::f32,
|
||||||
@ -191,7 +227,8 @@ const std::vector<FuseConvertTransformationTestValues> testValuesWithConstant =
|
|||||||
{},
|
{},
|
||||||
{1.f},
|
{1.f},
|
||||||
{0.45f}
|
{0.45f}
|
||||||
}
|
},
|
||||||
|
{}
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
|
|
||||||
#include <transformations/utils/utils.hpp>
|
#include <transformations/utils/utils.hpp>
|
||||||
#include <transformations/init_node_info.hpp>
|
#include <transformations/init_node_info.hpp>
|
||||||
#include <low_precision/fuse_fake_quantize.hpp>
|
#include <low_precision/fake_quantize.hpp>
|
||||||
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
|
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
|
||||||
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
|
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
|
||||||
|
|
||||||
@ -62,7 +62,7 @@ public:
|
|||||||
testValues.actual.fakeQuantizeOnData);
|
testValues.actual.fakeQuantizeOnData);
|
||||||
|
|
||||||
SimpleLowPrecisionTransformer transformer;
|
SimpleLowPrecisionTransformer transformer;
|
||||||
transformer.add<ngraph::pass::low_precision::FuseFakeQuantizeTransformation, ngraph::opset1::FakeQuantize>(testValues.params);
|
transformer.add<ngraph::pass::low_precision::FakeQuantizeTransformation, ngraph::opset1::FakeQuantize>(testValues.params);
|
||||||
transformer.transform(actualFunction);
|
transformer.transform(actualFunction);
|
||||||
|
|
||||||
referenceFunction = ngraph::builder::subgraph::FuseFakeQuantizeFunction::get(
|
referenceFunction = ngraph::builder::subgraph::FuseFakeQuantizeFunction::get(
|
||||||
|
@ -35,7 +35,6 @@
|
|||||||
|
|
||||||
// cleanup transformations
|
// cleanup transformations
|
||||||
#include "low_precision/fuse_convert.hpp"
|
#include "low_precision/fuse_convert.hpp"
|
||||||
#include "low_precision/fuse_fake_quantize.hpp"
|
|
||||||
#include "low_precision/fuse_subtract_to_fake_quantize.hpp"
|
#include "low_precision/fuse_subtract_to_fake_quantize.hpp"
|
||||||
#include "low_precision/fuse_multiply_to_fake_quantize.hpp"
|
#include "low_precision/fuse_multiply_to_fake_quantize.hpp"
|
||||||
#include "low_precision/multiply_to_group_convolution.hpp"
|
#include "low_precision/multiply_to_group_convolution.hpp"
|
||||||
|
@ -117,10 +117,16 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations::Convert& convert) {
|
inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations::Convert& convert) {
|
||||||
|
if (convert.empty()) {
|
||||||
|
return out << "{}";
|
||||||
|
}
|
||||||
return out << "_" << (convert.outPrecision != element::undefined ? convert.outPrecision.get_type_name() : "");
|
return out << "_" << (convert.outPrecision != element::undefined ? convert.outPrecision.get_type_name() : "");
|
||||||
}
|
}
|
||||||
|
|
||||||
inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations::Subtract& subtract) {
|
inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations::Subtract& subtract) {
|
||||||
|
if (subtract.empty()) {
|
||||||
|
return out << "{}";
|
||||||
|
}
|
||||||
return out << "_" <<
|
return out << "_" <<
|
||||||
subtract.values << "_" <<
|
subtract.values << "_" <<
|
||||||
subtract.outPrecision << "_" <<
|
subtract.outPrecision << "_" <<
|
||||||
@ -132,6 +138,9 @@ inline std::ostream& operator<<(std::ostream& out, const DequantizationOperation
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations::Multiply& multiply) {
|
inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations::Multiply& multiply) {
|
||||||
|
if (multiply.empty()) {
|
||||||
|
return out << "{}";
|
||||||
|
}
|
||||||
return out << "_" <<
|
return out << "_" <<
|
||||||
multiply.values << "_" <<
|
multiply.values << "_" <<
|
||||||
multiply.outPrecision << "_" <<
|
multiply.outPrecision << "_" <<
|
||||||
@ -142,6 +151,9 @@ inline std::ostream& operator<<(std::ostream& out, const DequantizationOperation
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations& data) {
|
inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations& data) {
|
||||||
|
if (data.empty()) {
|
||||||
|
return out << "{}";
|
||||||
|
}
|
||||||
return out << "_" << data.convert << "_" << data.subtract << "_" << data.multiply;
|
return out << "_" << data.convert << "_" << data.subtract << "_" << data.multiply;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -54,6 +54,9 @@ inline std::ostream& operator<<(std::ostream& os, const std::vector<float>& valu
|
|||||||
}
|
}
|
||||||
|
|
||||||
inline std::ostream& operator<<(std::ostream& out, const FakeQuantizeOnData& data) {
|
inline std::ostream& operator<<(std::ostream& out, const FakeQuantizeOnData& data) {
|
||||||
|
if (data.empty()) {
|
||||||
|
return out << "{}";
|
||||||
|
}
|
||||||
return out << "_" << data.quantizationLevel << data.constantShape << "_" << data.inputLowValues << "_" << data.inputHighValues <<
|
return out << "_" << data.quantizationLevel << data.constantShape << "_" << data.inputLowValues << "_" << data.inputHighValues <<
|
||||||
"_" << data.outputLowValues << "_" << data.outputHighValues << "_" <<
|
"_" << data.outputLowValues << "_" << data.outputHighValues << "_" <<
|
||||||
(data.outputPrecision == ngraph::element::undefined ? "" : data.outputPrecision.get_type_name());
|
(data.outputPrecision == ngraph::element::undefined ? "" : data.outputPrecision.get_type_name());
|
||||||
@ -89,6 +92,9 @@ public:
|
|||||||
};
|
};
|
||||||
|
|
||||||
inline std::ostream& operator<<(std::ostream& out, const FakeQuantizeOnDataWithConstant& data) {
|
inline std::ostream& operator<<(std::ostream& out, const FakeQuantizeOnDataWithConstant& data) {
|
||||||
|
if (data.empty()) {
|
||||||
|
return out << "{}";
|
||||||
|
}
|
||||||
return out << "_" << data.quantizationLevel <<
|
return out << "_" << data.quantizationLevel <<
|
||||||
(data.constantShapes.empty() ? ngraph::Shape{} : data.constantShapes[0]) << "_" <<
|
(data.constantShapes.empty() ? ngraph::Shape{} : data.constantShapes[0]) << "_" <<
|
||||||
data.inputLowValues << "_" << data.inputHighValues << "_" <<
|
data.inputLowValues << "_" << data.inputHighValues << "_" <<
|
||||||
|
@ -20,6 +20,7 @@ public:
|
|||||||
const ngraph::PartialShape& inputShape,
|
const ngraph::PartialShape& inputShape,
|
||||||
const ngraph::element::Type inputPrecision,
|
const ngraph::element::Type inputPrecision,
|
||||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
|
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
|
||||||
|
const ngraph::builder::subgraph::FakeQuantizeOnData& fakeQuantize,
|
||||||
const bool constInput);
|
const bool constInput);
|
||||||
|
|
||||||
static std::shared_ptr<ngraph::Function> getWithFQ(
|
static std::shared_ptr<ngraph::Function> getWithFQ(
|
||||||
|
@ -16,6 +16,7 @@ std::shared_ptr<ngraph::Function> FuseConvertFunction::get(
|
|||||||
const ngraph::PartialShape& inputShape,
|
const ngraph::PartialShape& inputShape,
|
||||||
const ngraph::element::Type inputPrecision,
|
const ngraph::element::Type inputPrecision,
|
||||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
|
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
|
||||||
|
const ngraph::builder::subgraph::FakeQuantizeOnData& fakeQuantize,
|
||||||
const bool constInput) {
|
const bool constInput) {
|
||||||
std::shared_ptr<Node> parent;
|
std::shared_ptr<Node> parent;
|
||||||
std::shared_ptr<op::Parameter> input;
|
std::shared_ptr<op::Parameter> input;
|
||||||
@ -28,14 +29,19 @@ std::shared_ptr<ngraph::Function> FuseConvertFunction::get(
|
|||||||
parent = input;
|
parent = input;
|
||||||
}
|
}
|
||||||
|
|
||||||
const std::shared_ptr<Node> dequantizationOp = makeDequantization(parent, dequantization);
|
parent = makeDequantization(parent, dequantization);
|
||||||
dequantizationOp->set_friendly_name("output");
|
|
||||||
|
if (!fakeQuantize.empty()) {
|
||||||
|
parent = makeFakeQuantize(parent, fakeQuantize.outputPrecision, fakeQuantize);
|
||||||
|
}
|
||||||
|
|
||||||
|
parent->set_friendly_name("output");
|
||||||
|
|
||||||
auto parameters = constInput ?
|
auto parameters = constInput ?
|
||||||
ngraph::ParameterVector{}:
|
ngraph::ParameterVector{}:
|
||||||
ngraph::ParameterVector{ input };
|
ngraph::ParameterVector{ input };
|
||||||
|
|
||||||
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(dequantizationOp) };
|
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(parent)};
|
||||||
return std::make_shared<ngraph::Function>(results, parameters, "FuseConvertFunction");
|
return std::make_shared<ngraph::Function>(results, parameters, "FuseConvertFunction");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user