[LPT] POT support: absent convert fix & element-wise empty dequantization data (#3067)

This commit is contained in:
Edward Shogulin 2020-11-13 10:32:59 +03:00 committed by GitHub
parent 17c67ddc5f
commit 4a362bddc5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 451 additions and 121 deletions

View File

@ -69,11 +69,13 @@ bool EltwiseBaseTransformation::canBeTransformed(const TransformationContext& co
return false;
}
if (dequantization1.empty() && !is_type<opset1::Constant>(dequantization1.data.get_node_shared_ptr())) {
if ((dequantization1.data.get_node() == nullptr) ||
(dequantization1.empty() && !is_type<opset1::Constant>(dequantization1.data.get_node_shared_ptr()))) {
return false;
}
if (dequantization2.empty() && !is_type<opset1::Constant>(dequantization2.data.get_node_shared_ptr())) {
if ((dequantization2.data.get_node() == nullptr) ||
(dequantization2.empty() && !is_type<opset1::Constant>(dequantization2.data.get_node_shared_ptr()))) {
return false;
}

View File

@ -948,7 +948,10 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
auto parent = newOperation;
if (shouldConvert) {
parent = std::make_shared<DequantizationConvert>(parent, dequantization.convert->get_output_element_type(0));
const auto convertOutputPrecision = dequantization.convert != nullptr ?
dequantization.convert->get_output_element_type(0) :
dequantization.multiply->get_output_element_type(0);
parent = std::make_shared<DequantizationConvert>(parent, convertOutputPrecision);
ngraph::copy_runtime_info({ newOperation, parent }, parent);
}

View File

@ -0,0 +1,161 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "layer_transformation.hpp"
#include <string>
#include <sstream>
#include <memory>
#include <gtest/gtest.h>
#include <utility>
#include <transformations/utils/utils.hpp>
#include <transformations/init_node_info.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "simple_low_precision_transformer.hpp"
#include <low_precision/add.hpp>
#include "ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp"
#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
using namespace testing;
using namespace ngraph::pass;
using namespace ngraph::builder::subgraph;
// One parameter set for the ElementwiseWithMultiParentDequantization test.
// All members are public and are filled positionally by the brace initializers
// of the test-value table in this file, so the declaration order of the
// members (and of the nested classes' members) must not be changed.
class ElementwiseWithMultiParentDequantizationTransformationTestValues {
public:
// Description of the graph that is built and then handed to the transformation.
class Actual {
public:
// Passed to the function builder as `precision1` (first Add branch).
ngraph::element::Type precision1;
// Dequantization operations for the first Add branch; may be empty.
ngraph::builder::subgraph::DequantizationOperations dequantization1;
// Passed to the function builder as `precision2` (second Add branch).
ngraph::element::Type precision2;
// Dequantization operations for the second Add branch; may be empty.
ngraph::builder::subgraph::DequantizationOperations dequantization2;
};
// Description of the reference graph the transformed graph is compared with.
// Same layout as Actual.
class Expected {
public:
ngraph::element::Type precision1;
ngraph::builder::subgraph::DequantizationOperations dequantization1;
ngraph::element::Type precision2;
ngraph::builder::subgraph::DequantizationOperations dequantization2;
};
// Passed to the function builder as `precision` and used in the test name.
ngraph::element::Type precision;
// Shape passed to the function builder for both the actual and the reference graph.
ngraph::Shape inputShape;
// Parameters used to configure AddTransformation; also forwarded to the builder.
ngraph::pass::low_precision::LayerTransformation::Params params;
Actual actual;
Expected expected;
};
// Streams a vector as "{ a, b, c }"; an empty vector is written as "{  }".
// Elements are written with their own operator<<, separated by ", ".
template <typename T>
inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
    os << "{ ";
    bool needSeparator = false;
    for (const auto& element : values) {
        if (needSeparator) {
            os << ", ";
        }
        os << element;
        needSeparator = true;
    }
    os << " }";
    return os;
}
// Parameterized fixture: builds the graph described by the `actual` part of the
// test values, runs AddTransformation on it, and builds the graph described by
// the `expected` part as the reference for comparison.
class ElementwiseWithMultiParentDequantizationTransformation :
    public LayerTransformation,
    public testing::WithParamInterface<ElementwiseWithMultiParentDequantizationTransformationTestValues> {
public:
    // Prepares `actualFunction` (already transformed) and `referenceFunction`.
    void SetUp() override {
        const auto& values = GetParam();
        const auto& source = values.actual;
        const auto& reference = values.expected;

        actualFunction = ElementwiseWithMultiParentDequantizationFunction::get(
            values.precision,
            values.inputShape,
            values.params,
            source.precision1,
            source.dequantization1,
            source.precision2,
            source.dequantization2);

        // Only AddTransformation is registered, so it is the only pass that can
        // modify the graph built above.
        SimpleLowPrecisionTransformer transformer;
        transformer.add<ngraph::pass::low_precision::AddTransformation, ngraph::opset1::Add>(
            low_precision::LayerTransformation::Params(values.params));
        transformer.transform(actualFunction);

        referenceFunction = ElementwiseWithMultiParentDequantizationFunction::get(
            values.precision,
            values.inputShape,
            values.params,
            reference.precision1,
            reference.dequantization1,
            reference.precision2,
            reference.dequantization2);
    }

    // Test name: precision, input shape and the `actual` branch descriptions,
    // joined with "_". The `expected` part is not included in the name.
    static std::string getTestCaseName(testing::TestParamInfo<ElementwiseWithMultiParentDequantizationTransformationTestValues> obj) {
        const auto& values = obj.param;

        std::ostringstream name;
        name << values.precision << "_";
        name << values.inputShape << "_";
        name << values.actual.precision1 << "_";
        name << values.actual.dequantization1 << "_";
        name << values.actual.precision2 << "_";
        name << values.actual.dequantization2;
        return name.str();
    }
};
// Re-validates the transformed graph, then requires it to match the reference
// graph; on mismatch the message returned by compare_functions is reported.
TEST_P(ElementwiseWithMultiParentDequantizationTransformation, CompareFunctions) {
    actualFunction->validate_nodes_and_infer_types();

    const auto comparison = compare_functions(referenceFunction, actualFunction, true, true, true);
    ASSERT_TRUE(comparison.first) << comparison.second;
}
// Test cases for AddTransformation when only one Add branch has dequantization
// operations. Each entry is { precision, inputShape, params, actual, expected };
// actual/expected are { precision1, dequantization1, precision2, dequantization2 }.
// In both cases `expected` is identical to `actual`, i.e. the reference graph is
// the same as the graph that is fed to the transformation.
const std::vector<ElementwiseWithMultiParentDequantizationTransformationTestValues> addTransformationTestValues = {
// U8 inputs: dequantization (convert to f32, subtract 7, multiply 10) on the
// first branch only; the second branch has no dequantization operations.
{
ngraph::element::f32,
ngraph::Shape{1, 4, 16, 16},
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{ {ngraph::element::f32}, { 7.f }, { 10.f }},
ngraph::element::u8,
{},
},
{
ngraph::element::u8,
{ {ngraph::element::f32}, { 7.f }, { 10.f }},
ngraph::element::u8,
{},
}
},
// U8 inputs: mirrored case, dequantization on the second branch only; the
// first branch has no dequantization operations.
{
ngraph::element::f32,
ngraph::Shape{1, 4, 16, 16},
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{},
ngraph::element::u8,
{ {ngraph::element::f32}, { 7.f }, { 10.f }}
},
{
ngraph::element::u8,
{},
ngraph::element::u8,
{ {ngraph::element::f32}, { 7.f }, { 10.f }}
}
}
};
// Instantiates the fixture once per entry of addTransformationTestValues under
// the "LPT" prefix; test names are produced by the fixture's getTestCaseName.
INSTANTIATE_TEST_CASE_P(
LPT,
ElementwiseWithMultiParentDequantizationTransformation,
::testing::ValuesIn(addTransformationTestValues),
ElementwiseWithMultiParentDequantizationTransformation::getTestCaseName);

View File

@ -17,94 +17,185 @@
#include "common_test_utils/ngraph_test_utils.hpp"
#include "simple_low_precision_transformer.hpp"
#include "ngraph_functions/low_precision_transformations/max_pool_function.hpp"
#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
using namespace testing;
using namespace ngraph::pass;
class MaxPoolTransformationTestValues {
public:
low_precision::LayerTransformation::Params params;
std::vector<float> subtractValues;
std::vector<float> mutliplyValues;
class Actual {
public:
ngraph::element::Type precisionBeforeDequantization;
ngraph::builder::subgraph::DequantizationOperations dequantization1;
ngraph::builder::subgraph::DequantizationOperations dequantization2;
};
class Expected {
public:
ngraph::element::Type precisionBeforeDequantization;
ngraph::builder::subgraph::DequantizationOperations dequantization1;
ngraph::builder::subgraph::DequantizationOperations dequantization2;
};
ngraph::pass::low_precision::LayerTransformation::Params params;
Actual actual;
Expected expected;
};
typedef std::tuple<
ngraph::element::Type,
ngraph::Shape,
MaxPoolTransformationTestValues> MaxPoolTransformationParams;
class MaxPoolTransformation : public LayerTransformation, public testing::WithParamInterface<MaxPoolTransformationParams> {
public:
void SetUp() override {
const ngraph::element::Type precision = std::get<0>(GetParam());
const ngraph::Shape shape = std::get<1>(GetParam());
const MaxPoolTransformationTestValues testValues = std::get<2>(GetParam());
const ngraph::Shape shape = std::get<0>(GetParam());
const MaxPoolTransformationTestValues testValues = std::get<1>(GetParam());
actualFunction = ngraph::builder::subgraph::MaxPoolFunction::getOriginal(
precision,
actualFunction = ngraph::builder::subgraph::MaxPoolFunction::get(
shape,
{
testValues.params.updatePrecisions ? testValues.params.precisionsOnActivations[0] : precision,
testValues.subtractValues,
testValues.mutliplyValues
});
testValues.actual.precisionBeforeDequantization,
testValues.actual.dequantization1,
testValues.actual.dequantization2);
SimpleLowPrecisionTransformer transform;
transform.add<ngraph::pass::low_precision::MaxPoolTransformation, ngraph::opset1::MaxPool>(testValues.params);
transform.transform(actualFunction);
referenceFunction = ngraph::builder::subgraph::MaxPoolFunction::getReference(
precision,
referenceFunction = ngraph::builder::subgraph::MaxPoolFunction::get(
shape,
{
testValues.params.updatePrecisions ? testValues.params.precisionsOnActivations[0] : precision,
testValues.subtractValues,
testValues.mutliplyValues
});
testValues.expected.precisionBeforeDequantization,
testValues.expected.dequantization1,
testValues.expected.dequantization2);
}
static std::string getTestCaseName(testing::TestParamInfo<MaxPoolTransformationParams> obj) {
const ngraph::element::Type precision = std::get<0>(obj.param);
const ngraph::Shape shape = std::get<1>(obj.param);
const MaxPoolTransformationTestValues testValues = std::get<2>(obj.param);
const ngraph::Shape shape = std::get<0>(obj.param);
const MaxPoolTransformationTestValues testValues = std::get<1>(obj.param);
std::ostringstream result;
result <<
LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" <<
testValues.subtractValues.size() << "_" <<
testValues.mutliplyValues.size() << "_";
LayerTransformation::getTestCaseNameByParams(testValues.actual.precisionBeforeDequantization, shape, testValues.params) << "_" <<
testValues.actual.dequantization1 << "_" <<
testValues.actual.dequantization2 << "_" <<
testValues.expected.dequantization1 << "_" <<
testValues.expected.dequantization2 << "_";
return result.str();
}
};
TEST_P(MaxPoolTransformation, CompareFunctions) {
actualFunction->validate_nodes_and_infer_types();
auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
auto res = compare_functions(referenceFunction, actualFunction, true, false, true);
ASSERT_TRUE(res.first) << res.second;
}
const std::vector<ngraph::element::Type> precisions = {
ngraph::element::f32,
// ngraph::element::f16
};
const std::vector<ngraph::Shape> shapes = {
{ 1, 32, 72, 48 }
{ 1, 32, 72, 48 },
{ 4, 32, 72, 48 }
};
const std::vector<MaxPoolTransformationTestValues> testValues = {
{ LayerTransformation::createParamsU8I8(), { 128 }, { 0.02f } },
{ LayerTransformation::createParamsU8I8(), {}, { 0.02f } },
{ LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), { 128 }, { 0.02f } },
{ LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), {}, { 0.02f } },
{ LayerTransformation::createParamsI8I8(), { 128 }, { 0.02f } },
// Multiply
{
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{ {}, {}, { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }},
{}
},
{
ngraph::element::u8,
{},
{ ngraph::element::f32, {}, { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }}
}
},
// Subtract + Multiply
{
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{
{},
{ {128.f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 },
{ {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }
},
{}
},
{
ngraph::element::u8,
{},
{
ngraph::element::f32,
{ {128.f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 },
{ {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }
}
}
},
// Convert + Subtract + Multiply
{
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{ ngraph::element::f32, { 128 }, { 0.02f }},
{}
},
{
ngraph::element::u8,
{},
{ ngraph::element::f32, { 128 }, { 0.02f }}
}
},
// Convert + Multiply (no Subtract)
{
LayerTransformation::createParamsU8I8(),
{
ngraph::element::u8,
{ ngraph::element::f32, {}, { 0.02f }},
{}
},
{
ngraph::element::u8,
{},
{ ngraph::element::f32, {}, { 0.02f }}
}
},
// Convert + Subtract + Multiply
{
LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
{
ngraph::element::u8,
{ ngraph::element::f32, { 128 }, { 0.02f }},
{}
},
{
ngraph::element::u8,
{},
{ ngraph::element::f32, { 128 }, { 0.02f }}
}
},
// Convert + Multiply (no Subtract), updatePrecisions = false
{
LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
{
ngraph::element::u8,
{ ngraph::element::f32, {}, { 0.02f }},
{}
},
{
ngraph::element::u8,
{},
{ ngraph::element::f32, {}, { 0.02f }}
}
}
};
INSTANTIATE_TEST_CASE_P(
LPT,
MaxPoolTransformation,
::testing::Combine(
::testing::ValuesIn(precisions),
::testing::ValuesIn(shapes),
::testing::ValuesIn(testValues)),
MaxPoolTransformation::getTestCaseName);

View File

@ -0,0 +1,71 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <ngraph/ngraph.hpp>
#include "functional_test_utils/low_precision_transformations/layer_transformation.hpp"
#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
#include "ngraph_functions/subgraph_builders.hpp"
#include "ngraph_functions/low_precision_transformations/common/builders.hpp"
namespace ngraph {
namespace builder {
namespace subgraph {
// Plain value holder: a precision plus subtract/multiply value lists for each
// of two inputs (suffix 1 and suffix 2). All members are public.
// NOTE(review): nothing visible in this header or its .cpp uses this class;
// presumably it mirrors a type from the Add test helpers - confirm it does not
// clash when both headers are included in one translation unit.
class AddActualValues {
public:
ngraph::element::Type precision1;
std::vector<float> subtractValues1;
// Spelling "mutliply" is part of the public member name; do not rename here.
std::vector<float> mutliplyValues1;
ngraph::element::Type precision2;
std::vector<float> subtractValues2;
std::vector<float> mutliplyValues2;
};
// Writes a compact description of AddActualValues: each precision followed by
// the element counts (not the contents) of the subtract/multiply vectors.
inline std::ostream& operator<<(std::ostream& out, const AddActualValues& values) {
    // First input.
    out << "_" << values.precision1;
    out << "_subtract" << values.subtractValues1.size();
    out << "_mutliply" << values.mutliplyValues1.size();

    // Second input.
    out << "_" << values.precision2;
    out << "_subtract" << values.subtractValues2.size();
    out << "_mutliply" << values.mutliplyValues2.size();

    return out;
}
// Plain value holder: precision and subtract/multiply value lists for a first
// input, a precision for a second input, and a multiply value list named
// "after". All members are public.
// NOTE(review): nothing visible in this header or its .cpp uses this class;
// confirm it is needed here.
class AddExpectedValues {
public:
ngraph::element::Type precision1;
std::vector<float> subtractValues1;
// Spelling "mutliply" is part of the public member name; do not rename here.
std::vector<float> mutliplyValues1;
ngraph::element::Type precision2;
std::vector<float> mutliplyValuesAfter;
};
// Writes a compact description of AddExpectedValues: the precisions and the
// element counts (not the contents) of the value vectors.
inline std::ostream& operator<<(std::ostream& out, const AddExpectedValues& values) {
    out << "_" << values.precision1;
    out << "_subtract" << values.subtractValues1.size();
    out << "_mutliply" << values.mutliplyValues1.size();
    out << "_" << values.precision2;
    out << "_mutliply" << values.mutliplyValuesAfter.size();
    return out;
}
// Builder for the test graph used by the ElementwiseWithMultiParentDequantization
// tests: an Add whose two inputs each come from a Multiply of two Parameters,
// optionally followed by dequantization operations.
class ElementwiseWithMultiParentDequantizationFunction {
public:
// Builds the graph.
//   precision        - original (floating point) precision of the test case.
//   inputShape       - shape of the main Parameter of each branch.
//   params           - low precision transformation parameters of the test case.
//   precision1       - element type for the Parameters of the first branch.
//   dequantization1  - dequantization operations of the first branch (may be empty).
//   precision2       - element type for the Parameters of the second branch.
//   dequantization2  - dequantization operations of the second branch (may be empty).
// NOTE(review): confirm against the definition which of `precision`, `params`
// and `precision2` are actually consumed.
static std::shared_ptr<ngraph::Function> get(
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,
const ngraph::pass::low_precision::LayerTransformation::Params& params,
const ngraph::element::Type& precision1,
const ngraph::builder::subgraph::DequantizationOperations& dequantization1,
const ngraph::element::Type& precision2,
const ngraph::builder::subgraph::DequantizationOperations& dequantization2);
};
} // namespace subgraph
} // namespace builder
} // namespace ngraph

View File

@ -8,6 +8,7 @@
#include <ngraph/ngraph.hpp>
#include "common/fake_quantize_on_data.hpp"
#include "low_precision/layer_transformation.hpp"
#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
namespace ngraph {
namespace builder {
@ -15,34 +16,16 @@ namespace subgraph {
class MaxPoolFunction {
public:
class ActualValues {
public:
ngraph::element::Type lowPrecision;
std::vector<float> subtractValues;
std::vector<float> mutliplyValues;
};
class ExpectedValues {
public:
ngraph::element::Type activationPrecision;
std::vector<float> subtractValues;
std::vector<float> mutliplyValues;
};
static std::shared_ptr<ngraph::Function> getOriginal(
const ngraph::element::Type originalFunctionPrecision,
const ngraph::Shape& inputShape,
const ActualValues& values);
static std::shared_ptr<ngraph::Function> getOriginal(
const ngraph::element::Type originalFunctionPrecision,
const ngraph::Shape& inputShape,
const FakeQuantizeOnData& fakeQuantizeOnData);
static std::shared_ptr<ngraph::Function> getReference(
const ngraph::element::Type originalFunctionPrecision,
static std::shared_ptr<ngraph::Function> get(
const ngraph::Shape& inputShape,
const ExpectedValues& values);
const ngraph::element::Type precisionBeforeDequantization,
const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter);
};
} // namespace subgraph

View File

@ -0,0 +1,60 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp"
#include "low_precision/network_helper.hpp"
#include <ngraph/opsets/opset1.hpp>
#include "ngraph_functions/builders.hpp"
#include "ngraph_functions/subgraph_builders.hpp"
using namespace ngraph::pass::low_precision;
namespace ngraph {
namespace builder {
namespace subgraph {
// Builds the test graph
//
//   Parameter  Parameter      Parameter  Parameter
//        \       /                 \       /
//        Multiply                  Multiply
//           |                         |
//   [dequantization1]         [dequantization2]
//             \                   /
//                      Add ("output")
//
// Each branch multiplies a Parameter of shape `inputShape` by a Parameter of
// shape { inputShape[0], inputShape[1], 1, 1 }. The Multiply is type-relaxed so
// that both inputs are treated as f32 regardless of the Parameter precision.
// A branch's dequantization operations are appended only when they are not empty.
//
// Parameters:
//   precision        - not used by this builder (kept for interface compatibility).
//   inputShape       - must have at least two dimensions; [0] and [1] are read unchecked.
//   params           - not used by this builder (kept for interface compatibility).
//   precision1       - element type of both Parameters of the first branch.
//   dequantization1  - dequantization operations of the first branch.
//   precision2       - element type of both Parameters of the second branch.
//   dequantization2  - dequantization operations of the second branch.
//
// Returns a Function named "ElementwiseWithMultiParentDequantization" whose
// parameters are ordered: first branch (full shape, per-channel shape), then
// second branch (full shape, per-channel shape).
std::shared_ptr<ngraph::Function> ElementwiseWithMultiParentDequantizationFunction::get(
    const ngraph::element::Type precision,
    const ngraph::Shape& inputShape,
    const ngraph::pass::low_precision::LayerTransformation::Params& params,
    const ngraph::element::Type& precision1,
    const ngraph::builder::subgraph::DequantizationOperations& dequantization1,
    const ngraph::element::Type& precision2,
    const ngraph::builder::subgraph::DequantizationOperations& dequantization2) {
    ngraph::ParameterVector parameters;

    // Both branches have the same structure; building them through one helper
    // keeps the per-branch precision and dequantization from being mixed up.
    const auto makeBranch = [&inputShape, &parameters](
        const ngraph::element::Type& branchPrecision,
        const ngraph::builder::subgraph::DequantizationOperations& dequantization) -> std::shared_ptr<ngraph::Node> {
        const auto fullInput = std::make_shared<ngraph::opset1::Parameter>(branchPrecision, inputShape);
        const auto channelInput = std::make_shared<ngraph::opset1::Parameter>(
            branchPrecision,
            ngraph::Shape({ inputShape[0], inputShape[1], 1, 1 }));
        parameters.push_back(fullInput);
        parameters.push_back(channelInput);

        const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Multiply>>(
            DequantizationMultiply(
                ngraph::op::TemporaryReplaceOutputType(fullInput, element::f32).get(),
                ngraph::op::TemporaryReplaceOutputType(channelInput, element::f32).get()),
            std::vector<element::Type>{element::f32, element::f32},
            std::vector<element::Type>{});

        return dequantization.empty() ? multiply : makeDequantization(multiply, dequantization);
    };

    const std::shared_ptr<ngraph::Node> parent1 = makeBranch(precision1, dequantization1);
    // The second branch uses precision2 (previously precision1 was used here,
    // which silently ignored the precision2 argument).
    const std::shared_ptr<ngraph::Node> parent2 = makeBranch(precision2, dequantization2);

    const auto add = std::make_shared<ngraph::opset1::Add>(parent1, parent2);
    add->set_friendly_name("output");

    auto& rtInfo = add->get_rt_info();
    rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("add");

    ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(add) };
    return std::make_shared<ngraph::Function>(results, parameters, "ElementwiseWithMultiParentDequantization");
}
} // namespace subgraph
} // namespace builder
} // namespace ngraph

View File

@ -14,41 +14,6 @@ namespace ngraph {
namespace builder {
namespace subgraph {
std::shared_ptr<ngraph::Function> MaxPoolFunction::getOriginal(
const ngraph::element::Type originalFunctionPrecision,
const ngraph::Shape& inputShape,
const ActualValues& values) {
const auto input = std::make_shared<ngraph::opset1::Parameter>(values.lowPrecision, ngraph::Shape(inputShape));
std::shared_ptr<ngraph::Node> parent = input;
const std::shared_ptr<ngraph::Node> convert = std::make_shared<ngraph::opset1::Convert>(parent, originalFunctionPrecision);
parent = convert;
if (!values.subtractValues.empty()) {
const std::shared_ptr<ngraph::Node> subtract = std::make_shared<ngraph::opset1::Subtract>(
parent,
std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.subtractValues.size() }), values.subtractValues));
parent = subtract;
}
const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::opset1::Multiply>(
parent,
std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.mutliplyValues.size() }), values.mutliplyValues));
parent = multiply;
const std::shared_ptr<ngraph::Node> maxPool = std::make_shared<ngraph::opset1::MaxPool>(
parent,
Strides{ 1, 1 },
Shape{ 1, 1 },
Shape{ 0, 0 },
Shape{ 2, 2 },
op::RoundingType::FLOOR);
maxPool->set_friendly_name("output");
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(maxPool) };
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation");
}
std::shared_ptr<ngraph::Function> MaxPoolFunction::getOriginal(
const ngraph::element::Type originalFunctionPrecision,
const ngraph::Shape& inputShape,
@ -71,13 +36,16 @@ std::shared_ptr<ngraph::Function> MaxPoolFunction::getOriginal(
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation");
}
std::shared_ptr<ngraph::Function> MaxPoolFunction::getReference(
const ngraph::element::Type originalFunctionPrecision,
std::shared_ptr<ngraph::Function> MaxPoolFunction::get(
const ngraph::Shape& inputShape,
const ExpectedValues& values) {
auto input = std::make_shared<ngraph::opset1::Parameter>(values.activationPrecision, ngraph::Shape(inputShape));
const ngraph::element::Type precisionBeforeDequantization,
const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter) {
const auto input = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization, ngraph::Shape(inputShape));
std::shared_ptr<ngraph::Node> parent = input;
parent = dequantizationBefore.empty() ? parent : makeDequantization(parent, dequantizationBefore);
const std::shared_ptr<ngraph::Node> maxPool = std::make_shared<ngraph::opset1::MaxPool>(
parent,
Strides{ 1, 1 },
@ -87,25 +55,16 @@ std::shared_ptr<ngraph::Function> MaxPoolFunction::getReference(
op::RoundingType::FLOOR);
parent = maxPool;
if (parent->get_output_element_type(0) != originalFunctionPrecision) {
const std::shared_ptr<ngraph::Node> convert = std::make_shared<ngraph::pass::low_precision::DequantizationConvert>(parent, originalFunctionPrecision);
parent = convert;
}
parent = dequantizationAfter.empty() ? maxPool : makeDequantization(maxPool, dequantizationAfter);
maxPool->set_friendly_name("maxPool");
if (!values.subtractValues.empty()) {
const std::shared_ptr<ngraph::Node> subtract = std::make_shared<ngraph::pass::low_precision::DequantizationSubtract>(
parent,
std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.subtractValues.size() }), values.subtractValues));
parent = subtract;
}
const std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(parent);
const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::pass::low_precision::DequantizationMultiply>(
parent,
std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.mutliplyValues.size() }), values.mutliplyValues));
multiply->set_friendly_name("output");
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(multiply) };
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation");
const std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
ngraph::ResultVector{ result },
std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
"MaxPoolTransformation");
return function;
}
} // namespace subgraph