[LPT] Correct handling of Dq after operations with several outputs (#4797)

* [LPT] Dq after Split/VariadicSplit fix

* [LPT][TESTS] Transformations after split tests
This commit is contained in:
Vladislav Golubev 2021-03-25 23:19:14 +03:00 committed by GitHub
parent 522ad39a48
commit 8b4837ea62
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 449 additions and 14 deletions

View File

@ -178,10 +178,10 @@ bool AddTransformation::transform(TransformationContext& context, ngraph::patter
}
// graph update
std::vector<std::shared_ptr<Node>> inputs{ {}, {} };
std::vector<Output<Node>> inputs{ {}, {} };
auto fullPathInput = dequantizationFullPath.convert == nullptr ? dequantizationFullPath.data : dequantizationFullPath.convert;
inputs[emptyPathIndex] = dequantizationEmptyPath.data.get_node_shared_ptr();
inputs[emptyPathIndex] = dequantizationEmptyPath.data;
inputs[fullPathIndex] = std::make_shared<DequantizationMultiply>(
newSubtractFullPathValues == nullptr ?
fullPathInput :

View File

@ -77,7 +77,7 @@ bool ClampTransformation::transform(TransformationContext& context, ngraph::patt
max += shift;
}
replacement = std::make_shared<ngraph::opset1::Clamp>(newClamp->get_input_node_shared_ptr(0), min, max);
replacement = std::make_shared<ngraph::opset1::Clamp>(newClamp->get_input_source_output(0), min, max);
}
replace_node(newClamp, replacement);

View File

@ -165,7 +165,7 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
if (is_type<opset1::Convert>(convolution->get_input_node_ptr(0))) {
auto newConvolution = convolution->clone_with_new_inputs({
convolution->get_input_node_ptr(0)->get_input_node_shared_ptr(0),
convolution->get_input_node_ptr(0)->get_input_source_output(0),
convolution->get_input_node_shared_ptr(1) });
replace_node(convolution, newConvolution);
convolution = newConvolution;
@ -253,7 +253,7 @@ bool ConvolutionTransformation::transform(TransformationContext &context, ngraph
std::shared_ptr<Node> childNode = reshapeFromWeights == nullptr ? convolution : reshapeFromWeights;
auto newConvolution = convolution->clone_with_new_inputs({
convolution->get_input_node_shared_ptr(0),
convolution->get_input_source_output(0),
childNode.get() == convolution.get() ?
convolution->get_input_node_ptr(1)->get_input_node_shared_ptr(0) :
childNode->copy_with_new_inputs({convertFromWeights->input_value(0), childNode->input_value(1)})});

View File

@ -182,8 +182,10 @@ std::shared_ptr<opset1::FakeQuantize> FakeQuantizeTransformation::fuseElementwis
return nullptr;
}
const auto data = fq::getData(eltwise);
const size_t outputIdx = NetworkHelper::getParentOutputIndex(data, eltwise);
std::shared_ptr<opset1::FakeQuantize> newFakeQuantize = as_type_ptr<opset1::FakeQuantize>(fakeQuantize->clone_with_new_inputs({
fq::getData(eltwise),
data->output(outputIdx),
inputLowConst_f32,
inputHightConst_f32,
fold<opset1::Convert>(fakeQuantize->input_value(3), deqPrecision),

View File

@ -63,10 +63,10 @@ bool MultiplyTransformation::transform(TransformationContext& context, ngraph::p
return false;
}
auto multiplyParent = multiply->get_input_node_shared_ptr(multiplyBranch.first);
auto constParent = multiply->get_input_node_shared_ptr(multiplyBranch.first == 0 ? 1 : 0);
auto multiplyParentParent = multiplyParent->get_input_node_shared_ptr(multiplyBranch.second);
auto multiplyParentConst = multiplyParent->get_input_node_shared_ptr(multiplyBranch.second == 0 ? 1 : 0);
auto multiplyParent = multiply->get_input_source_output(multiplyBranch.first);
auto constParent = multiply->get_input_source_output(multiplyBranch.first == 0 ? 1 : 0);
auto multiplyParentParent = multiplyParent.get_node_shared_ptr()->get_input_source_output(multiplyBranch.second);
auto multiplyParentConst = multiplyParent.get_node_shared_ptr()->get_input_source_output(multiplyBranch.second == 0 ? 1 : 0);
newMultiply = std::make_shared<op::TypeRelaxed<opset1::Multiply>>(
std::vector<ngraph::element::Type>{ element::f32, element::f32 },
@ -78,7 +78,7 @@ bool MultiplyTransformation::transform(TransformationContext& context, ngraph::p
fold<opset1::Convert>(constParent, element::f32)),
element::f32).get());
NetworkHelper::copyInfo(multiplyParent, newMultiply);
NetworkHelper::copyInfo(multiplyParent.get_node_shared_ptr(), newMultiply);
NetworkHelper::copyInfo(multiply, newMultiply);
if (!FakeQuantizeDequantization::checkElementwise(newMultiply)) {
@ -118,7 +118,7 @@ bool MultiplyTransformation::transform(TransformationContext& context, ngraph::p
// after : Y = (SC1' * (X1 - SH1)) * (X2) , where :
// SC1' = SC1 * SC2
std::shared_ptr<Node> newMultiplyValuesFullPath = fold<opset1::Multiply>(multiplyValuesEmptyPath, multiplyValuesFullPath);
std::vector<Output<Node>> inputs{ {}, {} };
OutputVector inputs{ {}, {} };
inputs[emptyPathIndex] = dequantizationEmptyPath.data;
inputs[fullPathIndex] = std::make_shared<DequantizationMultiply>(
dequantizationFullPath.subtract == nullptr ?

View File

@ -1102,7 +1102,7 @@ FakeQuantizeDequantization NetworkHelper::getDequantization(const std::shared_pt
return 1ul;
};
Output<Node> dataNode = inPlace ? node : node->input_value(parentIndex);
Output<Node> dataNode = inPlace ? node->output(0) : node->input_value(parentIndex);
const std::shared_ptr<ngraph::opset1::Multiply> multiply = as_type_ptr<ngraph::opset1::Multiply>(dataNode.get_node_shared_ptr());
std::shared_ptr<opset1::Constant> multiplyConstant;

View File

@ -218,6 +218,7 @@ LowPrecisionTransformations LowPrecisionTransformer::getAllTransformations(const
add<FakeQuantizeTransformation, opset1::FakeQuantize>(params).
add<GroupConvolutionTransformation, opset1::GroupConvolution>(params).
add<InterpolateTransformation, opset1::Interpolate>(params).
add<InterpolateTransformation, opset4::Interpolate>(params).
add<MatMulTransformation, opset1::MatMul>(params).
add<MaxPoolTransformation, opset1::MaxPool>(params).
add<MultiplyTransformation, opset1::Multiply>(params).
@ -231,7 +232,6 @@ LowPrecisionTransformations LowPrecisionTransformer::getAllTransformations(const
add<StridedSliceTransformation, opset1::StridedSlice>(params).
add<TransposeTransformation, opset1::Transpose>(params).
add<UnsqueezeTransformation, opset1::Unsqueeze>(params).
add<InterpolateTransformation, opset4::Interpolate>(params).
addCleanup<FoldConvertTransformation, opset1::Subtract>(params).
addCleanup<FuseConvertTransformation, opset1::Multiply>(params).

View File

@ -0,0 +1,220 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "layer_transformation.hpp"
#include <string>
#include <sstream>
#include <memory>
#include <gtest/gtest.h>
#include <transformations/utils/utils.hpp>
// general transformations
#include "low_precision/add.hpp"
#include "low_precision/avg_pool.hpp"
#include "low_precision/clamp.hpp"
#include "low_precision/convolution.hpp"
#include "low_precision/depth_to_space.hpp"
#include "low_precision/fake_quantize.hpp"
#include "low_precision/interpolate.hpp"
#include "low_precision/mat_mul.hpp"
#include "low_precision/max_pool.hpp"
#include "low_precision/multiply.hpp"
#include "low_precision/mvn.hpp"
#include "low_precision/normalize_l2.hpp"
#include "low_precision/prelu.hpp"
#include "low_precision/reshape.hpp"
#include "low_precision/relu.hpp"
#include "low_precision/squeeze.hpp"
#include "low_precision/subtract.hpp"
#include "low_precision/strided_slice.hpp"
#include "low_precision/transpose.hpp"
#include "low_precision/unsqueeze.hpp"
// cleanup transformations
#include "low_precision/fuse_convert.hpp"
#include "low_precision/fuse_fake_quantize.hpp"
#include "low_precision/fuse_subtract_to_fake_quantize.hpp"
#include "low_precision/fuse_multiply_to_fake_quantize.hpp"
#include "low_precision/multiply_to_group_convolution.hpp"
#include "low_precision/subtract_multiply_to_multiply_add.hpp"
#include "lpt_ngraph_functions/transformations_after_split_function.hpp"
#include "common_test_utils/ngraph_test_utils.hpp"
#include "simple_low_precision_transformer.hpp"
namespace {
using namespace testing;
using namespace ngraph;
using namespace ngraph::pass;
// Builds a SimpleLowPrecisionTransformer with exactly one transformation registered.
//
// params - low precision transformation parameters forwarded to the registered transformation.
// name   - class name of the transformation to register, e.g. "AddTransformation".
//
// Returns the transformer with the requested transformation attached to the operation type
// listed for it below.
// Throws std::runtime_error when `name` is not one of the supported names; the offending name
// is included in the message so that a failing parameterized case is easy to identify.
SimpleLowPrecisionTransformer getTransformerWithTransformationByName(
    const ngraph::pass::low_precision::LayerTransformation::Params& params,
    const std::string& name) {
    using namespace pass::low_precision;

    SimpleLowPrecisionTransformer transformer;

    // general transformations
    if (name == "AddTransformation") {
        transformer.add<AddTransformation, opset1::Add>(params);
    } else if (name == "AvgPoolTransformation") {
        transformer.add<AvgPoolTransformation, opset1::AvgPool>(params);
    } else if (name == "ClampTransformation") {
        transformer.add<ClampTransformation, opset1::Clamp>(params);
    } else if (name == "ConvolutionTransformation") {
        transformer.add<ConvolutionTransformation, opset1::Convolution>(params);
    } else if (name == "DepthToSpaceTransformation") {
        transformer.add<DepthToSpaceTransformation, opset1::DepthToSpace>(params);
    } else if (name == "FakeQuantizeTransformation") {
        transformer.add<FakeQuantizeTransformation, opset1::FakeQuantize>(params);
    } else if (name == "InterpolateTransformation") {
        transformer.add<InterpolateTransformation, opset1::Interpolate>(params);
    } else if (name == "MatMulTransformation") {
        transformer.add<MatMulTransformation, opset1::MatMul>(params);
    } else if (name == "MaxPoolTransformation") {
        transformer.add<MaxPoolTransformation, opset1::MaxPool>(params);
    } else if (name == "MultiplyTransformation") {
        transformer.add<MultiplyTransformation, opset1::Multiply>(params);
    } else if (name == "MVNTransformation") {
        transformer.add<MVNTransformation, ngraph::op::MVN>(params);
    } else if (name == "NormalizeL2Transformation") {
        transformer.add<NormalizeL2Transformation, opset1::NormalizeL2>(params);
    } else if (name == "PReluTransformation") {
        transformer.add<PReluTransformation, opset1::PRelu>(params);
    } else if (name == "ReluTransformation") {
        transformer.add<ReluTransformation, opset1::Relu>(params);
    } else if (name == "ReshapeTransformation") {
        transformer.add<ReshapeTransformation, opset1::Reshape>(params);
    } else if (name == "SqueezeTransformation") {
        transformer.add<SqueezeTransformation, opset1::Squeeze>(params);
    } else if (name == "StridedSliceTransformation") {
        transformer.add<StridedSliceTransformation, opset1::StridedSlice>(params);
    } else if (name == "TransposeTransformation") {
        transformer.add<TransposeTransformation, opset1::Transpose>(params);
    } else if (name == "UnsqueezeTransformation") {
        transformer.add<UnsqueezeTransformation, opset1::Unsqueeze>(params);
    // cleanup transformations
    } else if (name == "FuseConvertTransformation") {
        transformer.add<FuseConvertTransformation, opset1::Multiply>(params);
    } else if (name == "FuseSubtractToFakeQuantizeTransformation") {
        transformer.add<FuseSubtractToFakeQuantizeTransformation, opset1::Subtract>(params);
    } else if (name == "FuseMultiplyToFakeQuantizeTransformation") {
        transformer.add<FuseMultiplyToFakeQuantizeTransformation, opset1::Multiply>(params);
    } else if (name == "MultiplyToGroupConvolutionTransformation") {
        transformer.add<MultiplyToGroupConvolutionTransformation, opset1::Multiply>(params);
    } else if (name == "SubtractMultiplyToMultiplyAddTransformation") {
        transformer.add<SubtractMultiplyToMultiplyAddTransformation, opset1::Multiply>(params);
    } else {
        throw std::runtime_error("unexpected transformation name: " + name);
    }

    return transformer;
}
// Parameterized fixture: the test parameter is a transformation class name. For each name the
// fixture builds the matching "layer after split" function and validates it before the test runs.
class TransformationsAfterSplitTransformation : public LayerTransformation, public testing::WithParamInterface<std::string> {
public:
    // Creates the function under test for the current parameter and runs shape/type inference on it.
    void SetUp() override {
        function = ngraph::builder::subgraph::TransformationsAfterSplitFunction::get(GetParam());
        function->validate_nodes_and_infer_types();
    }

    // Produces the test-case suffix: "additional_layer_name_" followed by the transformation name.
    static std::string getTestCaseName(testing::TestParamInfo<std::string> obj) {
        return "additional_layer_name_" + obj.param;
    }

protected:
    // Function built in SetUp and transformed by the test body.
    std::shared_ptr<ngraph::Function> function;
};
// Smoke check: running the single transformation selected by the test parameter on the
// function prepared in SetUp must not throw.
TEST_P(TransformationsAfterSplitTransformation, Run) {
    // Kept as a named local so the parameters outlive the transformer that was built from them.
    const auto lptParams = LayerTransformation::createParamsU8I8();
    SimpleLowPrecisionTransformer transformer = getTransformerWithTransformationByName(lptParams, GetParam());
    ASSERT_NO_THROW(transformer.transform(function));
}
// Names of the transformations under test. Each entry is used as the test parameter, so it must be
// handled both by getTransformerWithTransformationByName and by
// TransformationsAfterSplitFunction::getLayerByTransformationName (both throw on an unknown name).
const std::vector<std::string> transformationNames = {
"AddTransformation",
"AvgPoolTransformation",
"ClampTransformation",
"ConvolutionTransformation",
"DepthToSpaceTransformation",
"FakeQuantizeTransformation",
"InterpolateTransformation",
"MatMulTransformation",
"MaxPoolTransformation",
"MultiplyTransformation",
"MVNTransformation",
"NormalizeL2Transformation",
"PReluTransformation",
"ReluTransformation",
"ReshapeTransformation",
"SqueezeTransformation",
"StridedSliceTransformation",
"TransposeTransformation",
"UnsqueezeTransformation",
"FuseConvertTransformation",
"FuseSubtractToFakeQuantizeTransformation",
"FuseMultiplyToFakeQuantizeTransformation",
"MultiplyToGroupConvolutionTransformation",
"SubtractMultiplyToMultiplyAddTransformation",
};
// Instantiates the parameterized test once per transformation name; getTestCaseName supplies the
// readable per-case suffix.
INSTANTIATE_TEST_CASE_P(
smoke_LPT,
TransformationsAfterSplitTransformation,
::testing::ValuesIn(transformationNames),
TransformationsAfterSplitTransformation::getTestCaseName);
} // namespace

View File

@ -0,0 +1,26 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <string>
#include <memory>
#include <ngraph/ngraph.hpp>
namespace ngraph {
namespace builder {
namespace subgraph {
// Builder of test functions in which an additional layer, selected by transformation name,
// is attached to every output of a VariadicSplit.
class TransformationsAfterSplitFunction {
public:
// Returns a function consisting of one parameter, a VariadicSplit with two outputs and, on each
// split output, the layer created by getLayerByTransformationName for `transformationName`.
static std::shared_ptr<Function> get(const std::string transformationName);
// Creates on top of `parent` the subgraph used to exercise the transformation called
// `transformationName` (for example "AddTransformation") and returns its last node.
// The implementation throws std::runtime_error when the name is not recognised.
static std::shared_ptr<Node> getLayerByTransformationName(
const std::string transformationName,
const Output<Node> parent);
};
} // namespace subgraph
} // namespace builder
} // namespace ngraph

View File

@ -0,0 +1,187 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "lpt_ngraph_functions/transformations_after_split_function.hpp"
#include <string>
#include <ngraph/opsets/opset1.hpp>
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
#include "lpt_ngraph_functions/common/builders.hpp"
namespace ngraph {
namespace builder {
namespace subgraph {
// Builds the test function: a u8 parameter of shape 1x3x16x16 is split along axis 2 into two
// parts of length 8, and the layer selected by `transformationName` is attached to each part.
std::shared_ptr<Function> TransformationsAfterSplitFunction::get(const std::string transformationName) {
    const size_t branchCount = 2ul;

    const auto parameter = std::make_shared<opset1::Parameter>(element::u8, Shape{ 1, 3, 16, 16 });
    const auto splitAxis = opset1::Constant::create(element::i64, Shape{}, { 2 });
    const auto splitLengths = opset1::Constant::create(element::i64, Shape{ branchCount }, { 8, 8 });
    const auto split = std::make_shared<opset1::VariadicSplit>(parameter, splitAxis, splitLengths);

    // One Result per split output, each fed by its own copy of the additional layer.
    ResultVector results;
    results.reserve(branchCount);
    for (size_t branch = 0; branch != branchCount; ++branch) {
        const auto layer = getLayerByTransformationName(transformationName, split->output(branch));
        results.push_back(std::make_shared<opset1::Result>(layer));
    }

    return std::make_shared<ngraph::Function>(
        results,
        ngraph::ParameterVector{ parameter },
        "VariadicSplitAndAdditionalLayerTransformation");
}
std::shared_ptr<Node> TransformationsAfterSplitFunction::getLayerByTransformationName(
const std::string transformationName,
const Output<Node> parent) {
if (transformationName == "AddTransformation") {
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
const auto addConstant = opset1::Constant::create(element::f32, Shape{}, { 128.f });
return std::make_shared<opset1::Add>(dequantization, addConstant);
}
if (transformationName == "AvgPoolTransformation") {
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
return std::make_shared<ngraph::opset1::AvgPool>(
dequantization,
Strides{ 1, 1 },
Shape{ 1, 1 },
Shape{ 0, 0 },
Shape{ 2, 2 },
true,
op::RoundingType::FLOOR);
}
if (transformationName == "ClampTransformation") {
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
return std::make_shared<opset1::Clamp>(dequantization, 0.0, 6.0);
}
if (transformationName == "ConvolutionTransformation") {
const auto dequantizationOnData = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
const auto weights = opset1::Constant::create(element::i8, Shape{ 3, 3, 1, 1 }, { 2 });
const auto dequantizationOnWeights = makeDequantization(weights, { {element::f32}, {}, {0.3f} });
return std::make_shared<opset1::Convolution>(
dequantizationOnData,
dequantizationOnWeights,
Strides{ 1, 1 },
CoordinateDiff{ 0, 0 },
CoordinateDiff{ 0, 0 },
Strides{ 1, 1 });
}
if (transformationName == "DepthToSpaceTransformation") {
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
return std::make_shared<opset1::DepthToSpace>(dequantization, opset1::DepthToSpace::DepthToSpaceMode::BLOCKS_FIRST, 3);
}
if (transformationName == "FakeQuantizeTransformation") {
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
return makeFakeQuantize(dequantization, element::f32, { 256, Shape{}, { 0.f }, { 255.f }, { 0.f }, { 127.f } });
}
if (transformationName == "InterpolateTransformation") {
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
const auto outShape = opset1::Constant::create(element::i64, Shape{ 4 }, { 1, 4, 32, 32 });
op::v0::InterpolateAttrs attributes;
attributes.axes = AxisSet{ 2, 3 };
attributes.mode = "nearest";
attributes.align_corners = false;
attributes.antialias = false;
attributes.pads_begin = std::vector<size_t>{ 0ul };
attributes.pads_end = std::vector<size_t>{ 0ul };
return std::make_shared<opset1::Interpolate>(dequantization, outShape, attributes);
}
if (transformationName == "MatMulTransformation") {
const auto dequantizationOnData = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
const auto weights = opset1::Constant::create(element::i8, Shape{ 16, 16 }, { 2 });
const auto dequantizationOnWeights = makeDequantization(weights, { {element::f32}, {}, { 0.3f } });
return std::make_shared<opset1::MatMul>(dequantizationOnData, dequantizationOnWeights);
}
if (transformationName == "MaxPoolTransformation") {
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
return std::make_shared<ngraph::opset1::MaxPool>(
dequantization,
Strides{ 1, 1 },
Shape{ 1, 1 },
Shape{ 0, 0 },
Shape{ 2, 2 });
}
if (transformationName == "MultiplyTransformation") {
const auto dequantization = makeDequantization(parent, { {}, {}, {{ 2.f }, element::f32, {}} });
return makeDequantization(dequantization, { {}, {}, { 0.2f } });
}
if (transformationName == "MVNTransformation") {
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
return std::make_shared<ngraph::op::MVN>(dequantization, AxisSet{ 2, 3 });
}
if (transformationName == "NormalizeL2Transformation") {
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
const auto axesNode = opset1::Constant::create(element::u64, ngraph::Shape{ 3 }, { 1, 2, 3 });
return std::make_shared<ngraph::opset1::NormalizeL2>(dequantization, axesNode, 1e-6, ngraph::op::EpsMode::ADD);
}
if (transformationName == "PReluTransformation") {
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
const auto slope = std::make_shared<ngraph::opset1::Constant>(element::f32, Shape{}, std::vector<float> { 0.1f });
return std::make_shared<ngraph::opset1::PRelu>(dequantization, slope);
}
if (transformationName == "ReluTransformation") {
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
return std::make_shared<ngraph::opset1::Relu>(dequantization);
}
if (transformationName == "ReshapeTransformation") {
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
const auto reshapeConst = opset1::Constant::create(element::i64, ngraph::Shape{ 3 }, { 1, 3, -1 });
return std::make_shared<opset1::Reshape>(dequantization, reshapeConst, false);
}
if (transformationName == "SqueezeTransformation") {
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
const auto squeezeConst = opset1::Constant::create(element::i64, ngraph::Shape{ 1 }, { 0 });
return std::make_shared<opset1::Squeeze>(dequantization, squeezeConst);
}
if (transformationName == "StridedSliceTransformation") {
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
std::vector<int64_t> mask{ 1, 0, 1, 1 };
const auto beginParam = opset1::Constant::create(element::i64, Shape{ 4 }, { 0, 0, 0, 0 });
const auto endParam = opset1::Constant::create(element::i64, Shape{ 4 }, { 1, 2, 1, 1 });
const auto stridesParam = opset1::Constant::create(element::i64, Shape{ 4 }, { 1, 1, 1, 1 });
return std::make_shared<ngraph::opset1::StridedSlice>(dequantization, beginParam, endParam, stridesParam, mask, mask);
}
if (transformationName == "TransposeTransformation") {
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
const auto transposeConstant = opset1::Constant::create(element::i64, Shape{ 4 }, { 0, 1, 3, 2 });
return std::make_shared<ngraph::opset1::Transpose>(dequantization, transposeConstant);
}
if (transformationName == "UnsqueezeTransformation") {
const auto dequantization = makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
const auto unsqueezeConst = opset1::Constant::create(element::i64, ngraph::Shape{ 1 }, { 0 });
return std::make_shared<opset1::Unsqueeze>(dequantization, unsqueezeConst);
}
if (transformationName == "FuseConvertTransformation") {
return makeDequantization(parent, { {element::f32}, {}, { 0.1f } });
}
if (transformationName == "FuseSubtractToFakeQuantizeTransformation") {
const auto fakeQuantize = makeFakeQuantize(parent, element::f32, { 256, Shape{}, { 0.f }, { 255.f }, { 0.f }, { 127.f } });
return makeDequantization(fakeQuantize, { {}, {{ 128.f }, element::f32, {}}, {} });
}
if (transformationName == "FuseMultiplyToFakeQuantizeTransformation") {
const auto fakeQuantize = makeFakeQuantize(parent, element::f32, { 256, Shape{}, { 0.f }, { 255.f }, { 0.f }, { 127.f } });
return makeDequantization(fakeQuantize, { {}, {}, {{ 2.f }, element::f32, {}} });
}
if (transformationName == "MultiplyToGroupConvolutionTransformation") {
return makeDequantization(parent, { {}, {{ 128.f }, element::f32, {}}, { 2.f } });
}
if (transformationName == "SubtractMultiplyToMultiplyAddTransformation") {
return makeDequantization(parent, { {}, {{ 128.f }, element::f32, {}}, { 2.f } });
}
throw std::runtime_error("unexpected additional layer name");
}
} // namespace subgraph
} // namespace builder
} // namespace ngraph