[LPT] POT support: absent convert fix & element-wise empty dequantization data (#3067)
This commit is contained in:
parent
17c67ddc5f
commit
4a362bddc5
@ -69,11 +69,13 @@ bool EltwiseBaseTransformation::canBeTransformed(const TransformationContext& co
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dequantization1.empty() && !is_type<opset1::Constant>(dequantization1.data.get_node_shared_ptr())) {
|
||||
if ((dequantization1.data.get_node() == nullptr) ||
|
||||
(dequantization1.empty() && !is_type<opset1::Constant>(dequantization1.data.get_node_shared_ptr()))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (dequantization2.empty() && !is_type<opset1::Constant>(dequantization2.data.get_node_shared_ptr())) {
|
||||
if ((dequantization2.data.get_node() == nullptr) ||
|
||||
(dequantization2.empty() && !is_type<opset1::Constant>(dequantization2.data.get_node_shared_ptr()))) {
|
||||
return false;
|
||||
}
|
||||
|
||||
|
@ -948,7 +948,10 @@ NetworkHelper::InsertDequantizationResult NetworkHelper::moveDequantizationAfter
|
||||
|
||||
auto parent = newOperation;
|
||||
if (shouldConvert) {
|
||||
parent = std::make_shared<DequantizationConvert>(parent, dequantization.convert->get_output_element_type(0));
|
||||
const auto convertOutputPrecision = dequantization.convert != nullptr ?
|
||||
dequantization.convert->get_output_element_type(0) :
|
||||
dequantization.multiply->get_output_element_type(0);
|
||||
parent = std::make_shared<DequantizationConvert>(parent, convertOutputPrecision);
|
||||
ngraph::copy_runtime_info({ newOperation, parent }, parent);
|
||||
}
|
||||
|
||||
|
@ -0,0 +1,161 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "layer_transformation.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <memory>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <utility>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "simple_low_precision_transformer.hpp"
|
||||
|
||||
#include <low_precision/add.hpp>
|
||||
#include "ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp"
|
||||
#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph::pass;
|
||||
using namespace ngraph::builder::subgraph;
|
||||
|
||||
class ElementwiseWithMultiParentDequantizationTransformationTestValues {
|
||||
public:
|
||||
class Actual {
|
||||
public:
|
||||
ngraph::element::Type precision1;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization1;
|
||||
ngraph::element::Type precision2;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization2;
|
||||
};
|
||||
|
||||
class Expected {
|
||||
public:
|
||||
ngraph::element::Type precision1;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization1;
|
||||
ngraph::element::Type precision2;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization2;
|
||||
};
|
||||
|
||||
ngraph::element::Type precision;
|
||||
ngraph::Shape inputShape;
|
||||
ngraph::pass::low_precision::LayerTransformation::Params params;
|
||||
Actual actual;
|
||||
Expected expected;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline std::ostream& operator<<(std::ostream& os, const std::vector<T>& values) {
|
||||
os << "{ ";
|
||||
for (size_t i = 0; i < values.size(); ++i) {
|
||||
os << values[i];
|
||||
if (i != (values.size() - 1ul)) {
|
||||
os << ", ";
|
||||
}
|
||||
}
|
||||
os << " }";
|
||||
return os;
|
||||
}
|
||||
|
||||
class ElementwiseWithMultiParentDequantizationTransformation :
|
||||
public LayerTransformation,
|
||||
public testing::WithParamInterface<ElementwiseWithMultiParentDequantizationTransformationTestValues> {
|
||||
public:
|
||||
void SetUp() override {
|
||||
const ElementwiseWithMultiParentDequantizationTransformationTestValues testValues = GetParam();
|
||||
|
||||
actualFunction = ElementwiseWithMultiParentDequantizationFunction::get(
|
||||
testValues.precision,
|
||||
testValues.inputShape,
|
||||
testValues.params,
|
||||
testValues.actual.precision1,
|
||||
testValues.actual.dequantization1,
|
||||
testValues.actual.precision2,
|
||||
testValues.actual.dequantization2);
|
||||
|
||||
SimpleLowPrecisionTransformer transform;
|
||||
transform.add<ngraph::pass::low_precision::AddTransformation, ngraph::opset1::Add>(
|
||||
low_precision::LayerTransformation::Params(testValues.params));
|
||||
transform.transform(actualFunction);
|
||||
|
||||
referenceFunction = ElementwiseWithMultiParentDequantizationFunction::get(
|
||||
testValues.precision,
|
||||
testValues.inputShape,
|
||||
testValues.params,
|
||||
testValues.expected.precision1,
|
||||
testValues.expected.dequantization1,
|
||||
testValues.expected.precision2,
|
||||
testValues.expected.dequantization2);
|
||||
}
|
||||
|
||||
static std::string getTestCaseName(testing::TestParamInfo<ElementwiseWithMultiParentDequantizationTransformationTestValues> obj) {
|
||||
const ElementwiseWithMultiParentDequantizationTransformationTestValues testValues = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result <<
|
||||
testValues.precision << "_" <<
|
||||
testValues.inputShape << "_" <<
|
||||
testValues.actual.precision1 << "_" <<
|
||||
testValues.actual.dequantization1 << "_" <<
|
||||
testValues.actual.precision2 << "_" <<
|
||||
testValues.actual.dequantization2;
|
||||
return result.str();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(ElementwiseWithMultiParentDequantizationTransformation, CompareFunctions) {
|
||||
actualFunction->validate_nodes_and_infer_types();
|
||||
auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
const std::vector<ElementwiseWithMultiParentDequantizationTransformationTestValues> addTransformationTestValues = {
|
||||
// U8
|
||||
{
|
||||
ngraph::element::f32,
|
||||
ngraph::Shape{1, 4, 16, 16},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{ {ngraph::element::f32}, { 7.f }, { 10.f }},
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{ {ngraph::element::f32}, { 7.f }, { 10.f }},
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
}
|
||||
},
|
||||
// U8
|
||||
{
|
||||
ngraph::element::f32,
|
||||
ngraph::Shape{1, 4, 16, 16},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{ {ngraph::element::f32}, { 7.f }, { 10.f }}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{ {ngraph::element::f32}, { 7.f }, { 10.f }}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
LPT,
|
||||
ElementwiseWithMultiParentDequantizationTransformation,
|
||||
::testing::ValuesIn(addTransformationTestValues),
|
||||
ElementwiseWithMultiParentDequantizationTransformation::getTestCaseName);
|
@ -17,94 +17,185 @@
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "simple_low_precision_transformer.hpp"
|
||||
#include "ngraph_functions/low_precision_transformations/max_pool_function.hpp"
|
||||
#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
|
||||
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph::pass;
|
||||
|
||||
class MaxPoolTransformationTestValues {
|
||||
public:
|
||||
low_precision::LayerTransformation::Params params;
|
||||
std::vector<float> subtractValues;
|
||||
std::vector<float> mutliplyValues;
|
||||
class Actual {
|
||||
public:
|
||||
ngraph::element::Type precisionBeforeDequantization;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization1;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization2;
|
||||
};
|
||||
|
||||
class Expected {
|
||||
public:
|
||||
ngraph::element::Type precisionBeforeDequantization;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization1;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization2;
|
||||
};
|
||||
|
||||
ngraph::pass::low_precision::LayerTransformation::Params params;
|
||||
Actual actual;
|
||||
Expected expected;
|
||||
};
|
||||
|
||||
typedef std::tuple<
|
||||
ngraph::element::Type,
|
||||
ngraph::Shape,
|
||||
MaxPoolTransformationTestValues> MaxPoolTransformationParams;
|
||||
|
||||
class MaxPoolTransformation : public LayerTransformation, public testing::WithParamInterface<MaxPoolTransformationParams> {
|
||||
public:
|
||||
void SetUp() override {
|
||||
const ngraph::element::Type precision = std::get<0>(GetParam());
|
||||
const ngraph::Shape shape = std::get<1>(GetParam());
|
||||
const MaxPoolTransformationTestValues testValues = std::get<2>(GetParam());
|
||||
const ngraph::Shape shape = std::get<0>(GetParam());
|
||||
const MaxPoolTransformationTestValues testValues = std::get<1>(GetParam());
|
||||
|
||||
actualFunction = ngraph::builder::subgraph::MaxPoolFunction::getOriginal(
|
||||
precision,
|
||||
actualFunction = ngraph::builder::subgraph::MaxPoolFunction::get(
|
||||
shape,
|
||||
{
|
||||
testValues.params.updatePrecisions ? testValues.params.precisionsOnActivations[0] : precision,
|
||||
testValues.subtractValues,
|
||||
testValues.mutliplyValues
|
||||
});
|
||||
testValues.actual.precisionBeforeDequantization,
|
||||
testValues.actual.dequantization1,
|
||||
testValues.actual.dequantization2);
|
||||
|
||||
SimpleLowPrecisionTransformer transform;
|
||||
transform.add<ngraph::pass::low_precision::MaxPoolTransformation, ngraph::opset1::MaxPool>(testValues.params);
|
||||
transform.transform(actualFunction);
|
||||
|
||||
referenceFunction = ngraph::builder::subgraph::MaxPoolFunction::getReference(
|
||||
precision,
|
||||
referenceFunction = ngraph::builder::subgraph::MaxPoolFunction::get(
|
||||
shape,
|
||||
{
|
||||
testValues.params.updatePrecisions ? testValues.params.precisionsOnActivations[0] : precision,
|
||||
testValues.subtractValues,
|
||||
testValues.mutliplyValues
|
||||
});
|
||||
testValues.expected.precisionBeforeDequantization,
|
||||
testValues.expected.dequantization1,
|
||||
testValues.expected.dequantization2);
|
||||
}
|
||||
|
||||
static std::string getTestCaseName(testing::TestParamInfo<MaxPoolTransformationParams> obj) {
|
||||
const ngraph::element::Type precision = std::get<0>(obj.param);
|
||||
const ngraph::Shape shape = std::get<1>(obj.param);
|
||||
const MaxPoolTransformationTestValues testValues = std::get<2>(obj.param);
|
||||
const ngraph::Shape shape = std::get<0>(obj.param);
|
||||
const MaxPoolTransformationTestValues testValues = std::get<1>(obj.param);
|
||||
|
||||
std::ostringstream result;
|
||||
result <<
|
||||
LayerTransformation::getTestCaseNameByParams(precision, shape, testValues.params) << "_" <<
|
||||
testValues.subtractValues.size() << "_" <<
|
||||
testValues.mutliplyValues.size() << "_";
|
||||
LayerTransformation::getTestCaseNameByParams(testValues.actual.precisionBeforeDequantization, shape, testValues.params) << "_" <<
|
||||
testValues.actual.dequantization1 << "_" <<
|
||||
testValues.actual.dequantization2 << "_" <<
|
||||
testValues.expected.dequantization1 << "_" <<
|
||||
testValues.expected.dequantization2 << "_";
|
||||
return result.str();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(MaxPoolTransformation, CompareFunctions) {
|
||||
actualFunction->validate_nodes_and_infer_types();
|
||||
auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
|
||||
auto res = compare_functions(referenceFunction, actualFunction, true, false, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
const std::vector<ngraph::element::Type> precisions = {
|
||||
ngraph::element::f32,
|
||||
// ngraph::element::f16
|
||||
};
|
||||
|
||||
const std::vector<ngraph::Shape> shapes = {
|
||||
{ 1, 32, 72, 48 }
|
||||
{ 1, 32, 72, 48 },
|
||||
{ 4, 32, 72, 48 }
|
||||
};
|
||||
|
||||
const std::vector<MaxPoolTransformationTestValues> testValues = {
|
||||
{ LayerTransformation::createParamsU8I8(), { 128 }, { 0.02f } },
|
||||
{ LayerTransformation::createParamsU8I8(), {}, { 0.02f } },
|
||||
{ LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), { 128 }, { 0.02f } },
|
||||
{ LayerTransformation::createParamsU8I8().setUpdatePrecisions(false), {}, { 0.02f } },
|
||||
{ LayerTransformation::createParamsI8I8(), { 128 }, { 0.02f } },
|
||||
// Multiply
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{ {}, {}, { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }},
|
||||
{}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
{ ngraph::element::f32, {}, { {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }}
|
||||
}
|
||||
},
|
||||
// Subtract + Multiply
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{
|
||||
{},
|
||||
{ {128.f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 },
|
||||
{ {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }
|
||||
},
|
||||
{}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
{
|
||||
ngraph::element::f32,
|
||||
{ {128.f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 },
|
||||
{ {0.02f}, ngraph::element::f32, {}, true, 1, ngraph::element::f32 }
|
||||
}
|
||||
}
|
||||
},
|
||||
// Convert + Subtract + Multiply
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, { 128 }, { 0.02f }},
|
||||
{}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
{ ngraph::element::f32, { 128 }, { 0.02f }}
|
||||
}
|
||||
},
|
||||
// Convert + Subtract + Multiply
|
||||
{
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, {}, { 0.02f }},
|
||||
{}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
{ ngraph::element::f32, {}, { 0.02f }}
|
||||
}
|
||||
},
|
||||
// Convert + Subtract + Multiply
|
||||
{
|
||||
LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, { 128 }, { 0.02f }},
|
||||
{}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
{ ngraph::element::f32, { 128 }, { 0.02f }}
|
||||
}
|
||||
},
|
||||
// Convert + Subtract + Multiply
|
||||
{
|
||||
LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{ ngraph::element::f32, {}, { 0.02f }},
|
||||
{}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
{ ngraph::element::f32, {}, { 0.02f }}
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
LPT,
|
||||
MaxPoolTransformation,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(precisions),
|
||||
::testing::ValuesIn(shapes),
|
||||
::testing::ValuesIn(testValues)),
|
||||
MaxPoolTransformation::getTestCaseName);
|
||||
|
@ -0,0 +1,71 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
|
||||
#include "functional_test_utils/low_precision_transformations/layer_transformation.hpp"
|
||||
#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
|
||||
#include "ngraph_functions/subgraph_builders.hpp"
|
||||
#include "ngraph_functions/low_precision_transformations/common/builders.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
namespace subgraph {
|
||||
|
||||
class AddActualValues {
|
||||
public:
|
||||
ngraph::element::Type precision1;
|
||||
std::vector<float> subtractValues1;
|
||||
std::vector<float> mutliplyValues1;
|
||||
ngraph::element::Type precision2;
|
||||
std::vector<float> subtractValues2;
|
||||
std::vector<float> mutliplyValues2;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const AddActualValues& values) {
|
||||
return out <<
|
||||
"_" << values.precision1 <<
|
||||
"_subtract" << values.subtractValues1.size() <<
|
||||
"_mutliply" << values.mutliplyValues1.size() <<
|
||||
"_" << values.precision2 <<
|
||||
"_subtract" << values.subtractValues2.size() <<
|
||||
"_mutliply" << values.mutliplyValues2.size();
|
||||
}
|
||||
|
||||
class AddExpectedValues {
|
||||
public:
|
||||
ngraph::element::Type precision1;
|
||||
std::vector<float> subtractValues1;
|
||||
std::vector<float> mutliplyValues1;
|
||||
ngraph::element::Type precision2;
|
||||
std::vector<float> mutliplyValuesAfter;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const AddExpectedValues& values) {
|
||||
return out <<
|
||||
"_" << values.precision1 <<
|
||||
"_subtract" << values.subtractValues1.size() <<
|
||||
"_mutliply" << values.mutliplyValues1.size() <<
|
||||
"_" << values.precision2 <<
|
||||
"_mutliply" << values.mutliplyValuesAfter.size();
|
||||
}
|
||||
|
||||
class ElementwiseWithMultiParentDequantizationFunction {
|
||||
public:
|
||||
static std::shared_ptr<ngraph::Function> get(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::pass::low_precision::LayerTransformation::Params& params,
|
||||
const ngraph::element::Type& precision1,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization1,
|
||||
const ngraph::element::Type& precision2,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization2);
|
||||
};
|
||||
|
||||
} // namespace subgraph
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
@ -8,6 +8,7 @@
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include "common/fake_quantize_on_data.hpp"
|
||||
#include "low_precision/layer_transformation.hpp"
|
||||
#include "ngraph_functions/low_precision_transformations/common/dequantization_operations.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
@ -15,34 +16,16 @@ namespace subgraph {
|
||||
|
||||
class MaxPoolFunction {
|
||||
public:
|
||||
class ActualValues {
|
||||
public:
|
||||
ngraph::element::Type lowPrecision;
|
||||
std::vector<float> subtractValues;
|
||||
std::vector<float> mutliplyValues;
|
||||
};
|
||||
|
||||
class ExpectedValues {
|
||||
public:
|
||||
ngraph::element::Type activationPrecision;
|
||||
std::vector<float> subtractValues;
|
||||
std::vector<float> mutliplyValues;
|
||||
};
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getOriginal(
|
||||
const ngraph::element::Type originalFunctionPrecision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ActualValues& values);
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getOriginal(
|
||||
const ngraph::element::Type originalFunctionPrecision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const FakeQuantizeOnData& fakeQuantizeOnData);
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getReference(
|
||||
const ngraph::element::Type originalFunctionPrecision,
|
||||
static std::shared_ptr<ngraph::Function> get(
|
||||
const ngraph::Shape& inputShape,
|
||||
const ExpectedValues& values);
|
||||
const ngraph::element::Type precisionBeforeDequantization,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter);
|
||||
};
|
||||
|
||||
} // namespace subgraph
|
||||
|
@ -0,0 +1,60 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "ngraph_functions/low_precision_transformations/elementwise_with_multi_parent_dequantization_function.hpp"
|
||||
#include "low_precision/network_helper.hpp"
|
||||
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include "ngraph_functions/builders.hpp"
|
||||
#include "ngraph_functions/subgraph_builders.hpp"
|
||||
|
||||
using namespace ngraph::pass::low_precision;
|
||||
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
namespace subgraph {
|
||||
|
||||
std::shared_ptr<ngraph::Function> ElementwiseWithMultiParentDequantizationFunction::get(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::pass::low_precision::LayerTransformation::Params& params,
|
||||
const ngraph::element::Type& precision1,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization1,
|
||||
const ngraph::element::Type& precision2,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization2) {
|
||||
const auto input1_1 = std::make_shared<ngraph::opset1::Parameter>(precision1, inputShape);
|
||||
const auto input1_2 = std::make_shared<ngraph::opset1::Parameter>(precision1, ngraph::Shape({ inputShape[0], inputShape[1], 1, 1 }));
|
||||
const std::shared_ptr<ngraph::Node> multiply1 = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Multiply>>(
|
||||
DequantizationMultiply(
|
||||
ngraph::op::TemporaryReplaceOutputType(input1_1, element::f32).get(),
|
||||
ngraph::op::TemporaryReplaceOutputType(input1_2, element::f32).get()),
|
||||
std::vector<element::Type>{element::f32, element::f32},
|
||||
std::vector<element::Type>{});
|
||||
|
||||
const std::shared_ptr<ngraph::Node> parent1 = dequantization1.empty() ? multiply1 : makeDequantization(multiply1, dequantization1);
|
||||
|
||||
const auto input2_1 = std::make_shared<ngraph::opset1::Parameter>(precision1, inputShape);
|
||||
const auto input2_2 = std::make_shared<ngraph::opset1::Parameter>(precision1, ngraph::Shape({ inputShape[0], inputShape[1], 1, 1 }));
|
||||
const std::shared_ptr<ngraph::Node> multiply2 = std::make_shared<ngraph::op::TypeRelaxed<ngraph::opset1::Multiply>>(
|
||||
DequantizationMultiply(
|
||||
ngraph::op::TemporaryReplaceOutputType(input2_1, element::f32).get(),
|
||||
ngraph::op::TemporaryReplaceOutputType(input2_2, element::f32).get()),
|
||||
std::vector<element::Type>{element::f32, element::f32},
|
||||
std::vector<element::Type>{});
|
||||
|
||||
const std::shared_ptr<ngraph::Node> parent2 = dequantization2.empty() ? multiply2 : makeDequantization(multiply2, dequantization2);
|
||||
|
||||
const auto add = std::make_shared<ngraph::opset1::Add>(parent1, parent2);
|
||||
add->set_friendly_name("output");
|
||||
auto& rtInfo = add->get_rt_info();
|
||||
rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("add");
|
||||
|
||||
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(add) };
|
||||
ngraph::ParameterVector parameters = { input1_1, input1_2, input2_1, input2_2 };
|
||||
return std::make_shared<ngraph::Function>(results, parameters, "ElementwiseWithMultiParentDequantization");
|
||||
}
|
||||
|
||||
} // namespace subgraph
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
@ -14,41 +14,6 @@ namespace ngraph {
|
||||
namespace builder {
|
||||
namespace subgraph {
|
||||
|
||||
std::shared_ptr<ngraph::Function> MaxPoolFunction::getOriginal(
|
||||
const ngraph::element::Type originalFunctionPrecision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ActualValues& values) {
|
||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(values.lowPrecision, ngraph::Shape(inputShape));
|
||||
std::shared_ptr<ngraph::Node> parent = input;
|
||||
|
||||
const std::shared_ptr<ngraph::Node> convert = std::make_shared<ngraph::opset1::Convert>(parent, originalFunctionPrecision);
|
||||
parent = convert;
|
||||
|
||||
if (!values.subtractValues.empty()) {
|
||||
const std::shared_ptr<ngraph::Node> subtract = std::make_shared<ngraph::opset1::Subtract>(
|
||||
parent,
|
||||
std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.subtractValues.size() }), values.subtractValues));
|
||||
parent = subtract;
|
||||
}
|
||||
|
||||
const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::opset1::Multiply>(
|
||||
parent,
|
||||
std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.mutliplyValues.size() }), values.mutliplyValues));
|
||||
parent = multiply;
|
||||
|
||||
const std::shared_ptr<ngraph::Node> maxPool = std::make_shared<ngraph::opset1::MaxPool>(
|
||||
parent,
|
||||
Strides{ 1, 1 },
|
||||
Shape{ 1, 1 },
|
||||
Shape{ 0, 0 },
|
||||
Shape{ 2, 2 },
|
||||
op::RoundingType::FLOOR);
|
||||
maxPool->set_friendly_name("output");
|
||||
|
||||
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(maxPool) };
|
||||
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation");
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> MaxPoolFunction::getOriginal(
|
||||
const ngraph::element::Type originalFunctionPrecision,
|
||||
const ngraph::Shape& inputShape,
|
||||
@ -71,13 +36,16 @@ std::shared_ptr<ngraph::Function> MaxPoolFunction::getOriginal(
|
||||
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation");
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> MaxPoolFunction::getReference(
|
||||
const ngraph::element::Type originalFunctionPrecision,
|
||||
std::shared_ptr<ngraph::Function> MaxPoolFunction::get(
|
||||
const ngraph::Shape& inputShape,
|
||||
const ExpectedValues& values) {
|
||||
auto input = std::make_shared<ngraph::opset1::Parameter>(values.activationPrecision, ngraph::Shape(inputShape));
|
||||
const ngraph::element::Type precisionBeforeDequantization,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter) {
|
||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization, ngraph::Shape(inputShape));
|
||||
std::shared_ptr<ngraph::Node> parent = input;
|
||||
|
||||
parent = dequantizationBefore.empty() ? parent : makeDequantization(parent, dequantizationBefore);
|
||||
|
||||
const std::shared_ptr<ngraph::Node> maxPool = std::make_shared<ngraph::opset1::MaxPool>(
|
||||
parent,
|
||||
Strides{ 1, 1 },
|
||||
@ -87,25 +55,16 @@ std::shared_ptr<ngraph::Function> MaxPoolFunction::getReference(
|
||||
op::RoundingType::FLOOR);
|
||||
parent = maxPool;
|
||||
|
||||
if (parent->get_output_element_type(0) != originalFunctionPrecision) {
|
||||
const std::shared_ptr<ngraph::Node> convert = std::make_shared<ngraph::pass::low_precision::DequantizationConvert>(parent, originalFunctionPrecision);
|
||||
parent = convert;
|
||||
}
|
||||
parent = dequantizationAfter.empty() ? maxPool : makeDequantization(maxPool, dequantizationAfter);
|
||||
maxPool->set_friendly_name("maxPool");
|
||||
|
||||
if (!values.subtractValues.empty()) {
|
||||
const std::shared_ptr<ngraph::Node> subtract = std::make_shared<ngraph::pass::low_precision::DequantizationSubtract>(
|
||||
parent,
|
||||
std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.subtractValues.size() }), values.subtractValues));
|
||||
parent = subtract;
|
||||
}
|
||||
const std::shared_ptr<ngraph::opset1::Result> result = std::make_shared<ngraph::opset1::Result>(parent);
|
||||
|
||||
const std::shared_ptr<ngraph::Node> multiply = std::make_shared<ngraph::pass::low_precision::DequantizationMultiply>(
|
||||
parent,
|
||||
std::make_shared<ngraph::opset1::Constant>(originalFunctionPrecision, Shape({ values.mutliplyValues.size() }), values.mutliplyValues));
|
||||
multiply->set_friendly_name("output");
|
||||
|
||||
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(multiply) };
|
||||
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "MaxPoolTransformation");
|
||||
const std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
|
||||
ngraph::ResultVector{ result },
|
||||
std::vector<std::shared_ptr<ngraph::op::Parameter>> { input },
|
||||
"MaxPoolTransformation");
|
||||
return function;
|
||||
}
|
||||
|
||||
} // namespace subgraph
|
||||
|
Loading…
Reference in New Issue
Block a user