[LPT] FuseConvert transformation extension (#10558)
* [LPT] FuseConvert transformation extension * [LPT] Tests * [LPT] Cleanup & tests refactoring
This commit is contained in:
parent
d7ad1bd9cd
commit
5be402750a
@ -7,7 +7,6 @@
|
||||
#include <memory>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include "layer_transformation.hpp"
|
||||
#include "low_precision/fuse_fake_quantize.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
@ -7,7 +7,6 @@
|
||||
#include <memory>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include "layer_transformation.hpp"
|
||||
#include "low_precision/fuse_fake_quantize.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
|
@ -1,30 +0,0 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include "low_precision/layer_transformation.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
namespace low_precision {
|
||||
|
||||
class LP_TRANSFORMATIONS_API FuseFakeQuantizeTransformation : public LayerTransformation {
|
||||
public:
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
FuseFakeQuantizeTransformation(const Params& params);
|
||||
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
|
||||
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
|
||||
|
||||
private:
|
||||
std::shared_ptr<opset1::FakeQuantize> handle(
|
||||
TransformationContext& context,
|
||||
const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) const;
|
||||
};
|
||||
|
||||
} // namespace low_precision
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
@ -175,7 +175,8 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto data = fq::getDataNode(eltwise);
|
||||
// issue #79980
|
||||
const auto data = eltwise->get_input_size() == 1ul ? eltwise->get_input_node_shared_ptr(0) : fq::getDataNode(eltwise);
|
||||
const size_t outputIdx = NetworkHelper::getParentOutputIndex(data, eltwise);
|
||||
|
||||
const auto newFakeQuantize = ov::as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({
|
||||
|
@ -23,8 +23,14 @@ FuseConvertTransformation::FuseConvertTransformation(const Params& params) : Lay
|
||||
auto multiply = pattern::wrap_type<opset1::Multiply>({ pattern::wrap_type<opset1::Convert>(), pattern::wrap_type<opset1::Constant>() });
|
||||
auto subtract = pattern::wrap_type<opset1::Subtract>({ pattern::wrap_type<opset1::Convert>(), pattern::wrap_type<opset1::Constant>() });
|
||||
auto add = pattern::wrap_type<opset1::Add>({ pattern::wrap_type<opset1::Convert>(), pattern::wrap_type<opset1::Constant>() });
|
||||
auto fakeQuantize = pattern::wrap_type<opset1::FakeQuantize>({
|
||||
pattern::wrap_type<opset1::Convert>({pattern::wrap_type<opset1::Constant>()}),
|
||||
pattern::any_input(),
|
||||
pattern::any_input(),
|
||||
pattern::any_input(),
|
||||
pattern::any_input()});
|
||||
auto matcher = std::make_shared<ngraph::pattern::Matcher>(
|
||||
std::make_shared<pattern::op::Or>(OutputVector{ multiply, subtract, add }),
|
||||
std::make_shared<pattern::op::Or>(OutputVector{ multiply, subtract, add, fakeQuantize }),
|
||||
"FuseConvertTransformation");
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||
|
@ -1,193 +0,0 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "low_precision/fuse_fake_quantize.hpp"
|
||||
#include <memory>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
#include "low_precision/common/ie_lpt_exception.hpp"
|
||||
#include "low_precision/network_helper.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
namespace low_precision {
|
||||
|
||||
NGRAPH_RTTI_DEFINITION(ngraph::pass::low_precision::FuseFakeQuantizeTransformation, "FuseFakeQuantizeTransformation", 0);
|
||||
|
||||
FuseFakeQuantizeTransformation::FuseFakeQuantizeTransformation(const Params& params) : LayerTransformation(params) {
|
||||
auto matcher = pattern::wrap_type<opset1::FakeQuantize>();
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||
auto op = m.get_match_root();
|
||||
if (transformation_callback(op)) {
|
||||
return false;
|
||||
}
|
||||
return transform(*context, m);
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(matcher, "FuseFakeQuantizeTransformation");
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
bool FuseFakeQuantizeTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) {
|
||||
auto fakeQuantize = ov::as_type_ptr<ngraph::opset1::FakeQuantize>(m.get_match_root());
|
||||
if (!fakeQuantize)
|
||||
return false;
|
||||
|
||||
do {
|
||||
fakeQuantize = handle(context, fakeQuantize);
|
||||
} while (fakeQuantize != nullptr);
|
||||
return true;
|
||||
}
|
||||
|
||||
namespace fuse_fq {
|
||||
namespace {
|
||||
|
||||
std::shared_ptr<Node> updateShape(std::shared_ptr<Node> op, const PartialShape& targetPShape) {
|
||||
assert(targetPShape.is_static());
|
||||
assert(op->get_output_partial_shape(0).is_static());
|
||||
const Shape targetShape = targetPShape.to_shape();
|
||||
const Shape shape = op->get_output_shape(0);
|
||||
|
||||
if ((shape.size() < targetShape.size()) && (shape.size() > 1ul)) {
|
||||
op = fold<opset1::Unsqueeze>(
|
||||
op,
|
||||
std::make_shared<opset1::Constant>(ngraph::element::i32, Shape{ 1 }, std::vector<size_t>({ 0ul })));
|
||||
}
|
||||
return op;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> getDataNode(const std::shared_ptr<Node>& eltwise) {
|
||||
if (!ov::is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(0))) {
|
||||
return eltwise->get_input_node_shared_ptr(0);
|
||||
}
|
||||
|
||||
if (!ov::is_type<opset1::Constant>(eltwise->get_input_node_shared_ptr(1))) {
|
||||
return eltwise->get_input_node_shared_ptr(1);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<opset1::Constant> getConstant(const std::shared_ptr<Node>& eltwise) {
|
||||
if (eltwise->get_input_size() != 2) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
std::shared_ptr<opset1::Constant> constant = ov::as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(1));
|
||||
if (constant != nullptr) {
|
||||
return constant;
|
||||
}
|
||||
|
||||
return ov::as_type_ptr<opset1::Constant>(eltwise->get_input_node_shared_ptr(0));
|
||||
}
|
||||
|
||||
bool eltwiseWithConstant(const std::shared_ptr<Node>& eltwise) {
|
||||
std::shared_ptr<opset1::Constant> constant = getConstant(eltwise);
|
||||
if (constant == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
Shape shape = constant->get_shape();
|
||||
if ((!shape.empty()) && (shape_size(shape) != 1ul)) {
|
||||
const auto eltwisePShape = eltwise->get_output_partial_shape(0);
|
||||
if (eltwisePShape.rank().is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const size_t eltwiseOutRank = eltwisePShape.rank().get_length();
|
||||
if ((eltwiseOutRank - shape.size()) > 1) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if ((eltwiseOutRank - shape.size()) == 1ul) {
|
||||
shape.insert(shape.begin(), 1ul);
|
||||
}
|
||||
|
||||
for (size_t i = 2ul; i < shape.size(); ++i) {
|
||||
if (shape[i] != 1ul) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return getDataNode(eltwise) != nullptr;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
} // namespace fuse_fq
|
||||
|
||||
std::shared_ptr<opset1::FakeQuantize> FuseFakeQuantizeTransformation::handle(
|
||||
TransformationContext& context,
|
||||
const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize) const {
|
||||
const std::shared_ptr<Node> eltwise = fakeQuantize->get_input_node_shared_ptr(0);
|
||||
|
||||
std::shared_ptr<Node> inputLowConst = fakeQuantize->get_input_node_shared_ptr(1);
|
||||
std::shared_ptr<Node> inputHightConst = fakeQuantize->get_input_node_shared_ptr(2);
|
||||
|
||||
std::shared_ptr<opset1::Constant> constant = fuse_fq::getConstant(eltwise);
|
||||
if (ov::is_type<opset1::Multiply>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
|
||||
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
|
||||
constant :
|
||||
foldConvert(constant, eltwise->get_output_element_type(0));
|
||||
|
||||
inputLowConst = fuse_fq::updateShape(fold<opset1::Divide>(inputLowConst, value), fakeQuantize->get_output_partial_shape(0));
|
||||
inputHightConst = fuse_fq::updateShape(fold<opset1::Divide>(inputHightConst, value), fakeQuantize->get_output_partial_shape(0));
|
||||
} else if (ov::is_type<opset1::Divide>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
|
||||
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
|
||||
constant :
|
||||
foldConvert(constant, eltwise->get_output_element_type(0));
|
||||
|
||||
inputLowConst = fuse_fq::updateShape(fold<opset1::Multiply>(inputLowConst, value), fakeQuantize->get_output_partial_shape(0));
|
||||
inputHightConst = fuse_fq::updateShape(fold<opset1::Multiply>(inputHightConst, value), fakeQuantize->get_output_partial_shape(0));
|
||||
} else if (ov::is_type<opset1::Subtract>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
|
||||
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
|
||||
constant :
|
||||
foldConvert(constant, eltwise->get_output_element_type(0));
|
||||
|
||||
inputLowConst = fuse_fq::updateShape(fold<opset1::Add>(inputLowConst, value), fakeQuantize->get_output_partial_shape(0));
|
||||
inputHightConst = fuse_fq::updateShape(fold<opset1::Add>(inputHightConst, value), fakeQuantize->get_output_partial_shape(0));
|
||||
} else if (ov::is_type<opset1::Add>(eltwise) && fuse_fq::eltwiseWithConstant(eltwise)) {
|
||||
if (ov::is_type<opset1::Convolution>(fuse_fq::getDataNode(eltwise)) ||
|
||||
ov::is_type<opset1::GroupConvolution>(fuse_fq::getDataNode(eltwise))) {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto value = constant->get_output_element_type(0) == eltwise->get_output_element_type(0) ?
|
||||
constant :
|
||||
foldConvert(constant, eltwise->get_output_element_type(0));
|
||||
|
||||
inputLowConst = fuse_fq::updateShape(fold<opset1::Subtract>(inputLowConst, value), fakeQuantize->get_output_partial_shape(0));
|
||||
inputHightConst = fuse_fq::updateShape(fold<opset1::Subtract>(inputHightConst, value), fakeQuantize->get_output_partial_shape(0));
|
||||
} else if (ov::is_type<opset1::Convert>(eltwise)) {
|
||||
// issue #40611
|
||||
if ((eltwise->get_input_element_type(0) == element::i32) && (eltwise->get_output_element_type(0) == element::f32)) {
|
||||
return nullptr;
|
||||
}
|
||||
} else {
|
||||
return nullptr;
|
||||
}
|
||||
|
||||
const auto data = fuse_fq::getDataNode(eltwise);
|
||||
const size_t outputIdx = NetworkHelper::getParentOutputIndex(data, eltwise);
|
||||
|
||||
std::shared_ptr<opset1::FakeQuantize> newFakeQuantize = ov::as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({
|
||||
data->output(outputIdx),
|
||||
inputLowConst,
|
||||
inputHightConst,
|
||||
fakeQuantize->input_value(3),
|
||||
fakeQuantize->input_value(4) }));
|
||||
|
||||
replace_node(fakeQuantize, newFakeQuantize);
|
||||
NetworkHelper::copyInfo(fakeQuantize, newFakeQuantize);
|
||||
return newFakeQuantize;
|
||||
}
|
||||
|
||||
bool FuseFakeQuantizeTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
|
||||
return false;
|
||||
}
|
||||
|
||||
} // namespace low_precision
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
@ -74,7 +74,6 @@
|
||||
#include "low_precision/convert.hpp"
|
||||
#include "low_precision/fold_fake_quantize.hpp"
|
||||
#include "low_precision/fuse_convert.hpp"
|
||||
#include "low_precision/fuse_fake_quantize.hpp"
|
||||
#include "low_precision/fuse_subtract_to_fake_quantize.hpp"
|
||||
#include "low_precision/fuse_multiply_to_fake_quantize.hpp"
|
||||
#include "low_precision/multiply_to_group_convolution.hpp"
|
||||
|
@ -30,12 +30,14 @@ public:
|
||||
public:
|
||||
ngraph::element::Type inputPrecision;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization;
|
||||
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize;
|
||||
};
|
||||
|
||||
class Expected {
|
||||
public:
|
||||
ngraph::element::Type inputPrecision;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization;
|
||||
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize;
|
||||
};
|
||||
|
||||
bool constInput;
|
||||
@ -58,6 +60,7 @@ public:
|
||||
inputShape,
|
||||
testValues.actual.inputPrecision,
|
||||
testValues.actual.dequantization,
|
||||
testValues.actual.fakeQuantize,
|
||||
testValues.constInput);
|
||||
|
||||
SimpleLowPrecisionTransformer transformer;
|
||||
@ -68,6 +71,7 @@ public:
|
||||
inputShape,
|
||||
testValues.expected.inputPrecision,
|
||||
testValues.expected.dequantization,
|
||||
testValues.expected.fakeQuantize,
|
||||
testValues.constInput);
|
||||
}
|
||||
|
||||
@ -77,9 +81,13 @@ public:
|
||||
|
||||
std::ostringstream result;
|
||||
result <<
|
||||
inputShape << "_" <<
|
||||
testValues.actual.inputPrecision << "_" <<
|
||||
testValues.actual.dequantization << "_" <<
|
||||
"IS_" << inputShape << "_" <<
|
||||
"AIP_" << testValues.actual.inputPrecision << "_" <<
|
||||
"ADEQ_" << testValues.actual.dequantization << "_" <<
|
||||
"AFQ_" << testValues.actual.fakeQuantize << "_" <<
|
||||
"EIP_" << testValues.expected.inputPrecision << "_" <<
|
||||
"EDEQ_" << testValues.expected.dequantization << "_" <<
|
||||
"EFQ_" << testValues.expected.fakeQuantize << "_" <<
|
||||
testValues.constInput;
|
||||
return result.str();
|
||||
}
|
||||
@ -111,7 +119,8 @@ const std::vector<FuseConvertTransformationTestValues> testValues = {
|
||||
{ ngraph::element::f32 },
|
||||
{1.f},
|
||||
{0.45f}
|
||||
}
|
||||
},
|
||||
{}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
@ -119,7 +128,8 @@ const std::vector<FuseConvertTransformationTestValues> testValues = {
|
||||
{},
|
||||
DequantizationOperations::Subtract({1.f}, ngraph::element::f32).setConstantPrecision(ngraph::element::f32),
|
||||
{0.45f}
|
||||
}
|
||||
},
|
||||
{}
|
||||
}
|
||||
},
|
||||
// fuse to multiply
|
||||
@ -132,7 +142,8 @@ const std::vector<FuseConvertTransformationTestValues> testValues = {
|
||||
{ ngraph::element::f32 },
|
||||
{},
|
||||
{0.45f}
|
||||
}
|
||||
},
|
||||
{}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
@ -140,7 +151,8 @@ const std::vector<FuseConvertTransformationTestValues> testValues = {
|
||||
{},
|
||||
{},
|
||||
DequantizationOperations::Multiply({0.45f}, ngraph::element::f32).setConstantPrecision(ngraph::element::f32)
|
||||
}
|
||||
},
|
||||
{}
|
||||
}
|
||||
},
|
||||
// Convert with unexpected precision
|
||||
@ -149,11 +161,13 @@ const std::vector<FuseConvertTransformationTestValues> testValues = {
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::f32,
|
||||
{{ ngraph::element::i32 }, {}, {3.f}}
|
||||
{{ ngraph::element::i32 }, {}, {3.f}},
|
||||
{}
|
||||
},
|
||||
{
|
||||
ngraph::element::f32,
|
||||
{{ ngraph::element::i32 }, {}, {3.f}}
|
||||
{{ ngraph::element::i32 }, {}, {3.f}},
|
||||
{}
|
||||
}
|
||||
},
|
||||
};
|
||||
@ -173,6 +187,27 @@ const std::vector<ngraph::PartialShape> inputShapes = {
|
||||
};
|
||||
|
||||
const std::vector<FuseConvertTransformationTestValues> testValuesWithConstant = {
|
||||
// Constant
|
||||
// |
|
||||
// Convert Const Const Const Const
|
||||
// \ \ | / /
|
||||
// \ \ | / /
|
||||
// FakeQuantize
|
||||
//
|
||||
{
|
||||
true,
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {}, {}},
|
||||
{ 256, {}, {0.f}, {0.1f}, {0.f}, {0.1f}, ov::element::f32}
|
||||
},
|
||||
{
|
||||
ngraph::element::f32,
|
||||
{},
|
||||
{ 256, {}, {0.f}, {0.1f}, {0.f}, {0.1f}, ov::element::f32}
|
||||
}
|
||||
},
|
||||
// fuse to const
|
||||
{
|
||||
true,
|
||||
@ -183,7 +218,8 @@ const std::vector<FuseConvertTransformationTestValues> testValuesWithConstant =
|
||||
{ ngraph::element::f32 },
|
||||
{1.f},
|
||||
{0.45f}
|
||||
}
|
||||
},
|
||||
{}
|
||||
},
|
||||
{
|
||||
ngraph::element::f32,
|
||||
@ -191,7 +227,8 @@ const std::vector<FuseConvertTransformationTestValues> testValuesWithConstant =
|
||||
{},
|
||||
{1.f},
|
||||
{0.45f}
|
||||
}
|
||||
},
|
||||
{}
|
||||
}
|
||||
},
|
||||
};
|
||||
|
@ -12,7 +12,7 @@
|
||||
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <low_precision/fuse_fake_quantize.hpp>
|
||||
#include <low_precision/fake_quantize.hpp>
|
||||
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
|
||||
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
|
||||
|
||||
@ -62,7 +62,7 @@ public:
|
||||
testValues.actual.fakeQuantizeOnData);
|
||||
|
||||
SimpleLowPrecisionTransformer transformer;
|
||||
transformer.add<ngraph::pass::low_precision::FuseFakeQuantizeTransformation, ngraph::opset1::FakeQuantize>(testValues.params);
|
||||
transformer.add<ngraph::pass::low_precision::FakeQuantizeTransformation, ngraph::opset1::FakeQuantize>(testValues.params);
|
||||
transformer.transform(actualFunction);
|
||||
|
||||
referenceFunction = ngraph::builder::subgraph::FuseFakeQuantizeFunction::get(
|
||||
|
@ -35,7 +35,6 @@
|
||||
|
||||
// cleanup transformations
|
||||
#include "low_precision/fuse_convert.hpp"
|
||||
#include "low_precision/fuse_fake_quantize.hpp"
|
||||
#include "low_precision/fuse_subtract_to_fake_quantize.hpp"
|
||||
#include "low_precision/fuse_multiply_to_fake_quantize.hpp"
|
||||
#include "low_precision/multiply_to_group_convolution.hpp"
|
||||
|
@ -117,10 +117,16 @@ public:
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations::Convert& convert) {
|
||||
if (convert.empty()) {
|
||||
return out << "{}";
|
||||
}
|
||||
return out << "_" << (convert.outPrecision != element::undefined ? convert.outPrecision.get_type_name() : "");
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations::Subtract& subtract) {
|
||||
if (subtract.empty()) {
|
||||
return out << "{}";
|
||||
}
|
||||
return out << "_" <<
|
||||
subtract.values << "_" <<
|
||||
subtract.outPrecision << "_" <<
|
||||
@ -132,6 +138,9 @@ inline std::ostream& operator<<(std::ostream& out, const DequantizationOperation
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations::Multiply& multiply) {
|
||||
if (multiply.empty()) {
|
||||
return out << "{}";
|
||||
}
|
||||
return out << "_" <<
|
||||
multiply.values << "_" <<
|
||||
multiply.outPrecision << "_" <<
|
||||
@ -142,6 +151,9 @@ inline std::ostream& operator<<(std::ostream& out, const DequantizationOperation
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const DequantizationOperations& data) {
|
||||
if (data.empty()) {
|
||||
return out << "{}";
|
||||
}
|
||||
return out << "_" << data.convert << "_" << data.subtract << "_" << data.multiply;
|
||||
}
|
||||
|
||||
|
@ -54,6 +54,9 @@ inline std::ostream& operator<<(std::ostream& os, const std::vector<float>& valu
|
||||
}
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const FakeQuantizeOnData& data) {
|
||||
if (data.empty()) {
|
||||
return out << "{}";
|
||||
}
|
||||
return out << "_" << data.quantizationLevel << data.constantShape << "_" << data.inputLowValues << "_" << data.inputHighValues <<
|
||||
"_" << data.outputLowValues << "_" << data.outputHighValues << "_" <<
|
||||
(data.outputPrecision == ngraph::element::undefined ? "" : data.outputPrecision.get_type_name());
|
||||
@ -89,6 +92,9 @@ public:
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const FakeQuantizeOnDataWithConstant& data) {
|
||||
if (data.empty()) {
|
||||
return out << "{}";
|
||||
}
|
||||
return out << "_" << data.quantizationLevel <<
|
||||
(data.constantShapes.empty() ? ngraph::Shape{} : data.constantShapes[0]) << "_" <<
|
||||
data.inputLowValues << "_" << data.inputHighValues << "_" <<
|
||||
|
@ -20,6 +20,7 @@ public:
|
||||
const ngraph::PartialShape& inputShape,
|
||||
const ngraph::element::Type inputPrecision,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
|
||||
const ngraph::builder::subgraph::FakeQuantizeOnData& fakeQuantize,
|
||||
const bool constInput);
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getWithFQ(
|
||||
|
@ -16,6 +16,7 @@ std::shared_ptr<ngraph::Function> FuseConvertFunction::get(
|
||||
const ngraph::PartialShape& inputShape,
|
||||
const ngraph::element::Type inputPrecision,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
|
||||
const ngraph::builder::subgraph::FakeQuantizeOnData& fakeQuantize,
|
||||
const bool constInput) {
|
||||
std::shared_ptr<Node> parent;
|
||||
std::shared_ptr<op::Parameter> input;
|
||||
@ -28,14 +29,19 @@ std::shared_ptr<ngraph::Function> FuseConvertFunction::get(
|
||||
parent = input;
|
||||
}
|
||||
|
||||
const std::shared_ptr<Node> dequantizationOp = makeDequantization(parent, dequantization);
|
||||
dequantizationOp->set_friendly_name("output");
|
||||
parent = makeDequantization(parent, dequantization);
|
||||
|
||||
if (!fakeQuantize.empty()) {
|
||||
parent = makeFakeQuantize(parent, fakeQuantize.outputPrecision, fakeQuantize);
|
||||
}
|
||||
|
||||
parent->set_friendly_name("output");
|
||||
|
||||
auto parameters = constInput ?
|
||||
ngraph::ParameterVector{}:
|
||||
ngraph::ParameterVector{ input };
|
||||
|
||||
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(dequantizationOp) };
|
||||
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(parent)};
|
||||
return std::make_shared<ngraph::Function>(results, parameters, "FuseConvertFunction");
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user