[LPT] Concat dequantization shape fix (#4702)

This commit is contained in:
Vladimir Zinoviev 2021-03-15 17:47:56 +03:00 committed by GitHub
parent dbd3a3d7a4
commit 7ca7e7d9c4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 241 additions and 3 deletions

View File

@ -312,9 +312,9 @@ void ConcatTransformation::addDequantizationLayers(
convertNodes.push_back(dequantization.convert);
}
const ngraph::element::Type precision = deqPrecision; //dequantization.data.get_element_type();
ngraph::Shape targetShape(dequantization.data.get_shape().size(), 1ul);
targetShape[1] = dequantization.data.get_shape()[1];
const ngraph::element::Type precision = deqPrecision;
ngraph::Shape targetShape(layer->get_input_shape(i).size(), 1ul);
targetShape[1] = layer->get_input_shape(i)[1];
if (!allDequantizationShiftAreZero) {
subtractNodes.push_back(dequantization.subtract == nullptr ?

View File

@ -0,0 +1,143 @@
// Copyright (C) 2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "layer_transformation.hpp"
#include <string>
#include <sstream>
#include <memory>
#include <gtest/gtest.h>
#include <transformations/utils/utils.hpp>
#include <low_precision/reshape.hpp>
#include <low_precision/concat_multi_channels.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "lpt_ngraph_functions/concat_function.hpp"
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
#include "simple_low_precision_transformer.hpp"
using namespace testing;
using namespace ngraph;
using namespace ngraph::pass;
namespace {
// FakeQuantize descriptions used to build the graph that is handed to the
// transformation, one per Concat input branch.
struct ActualValues {
    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
};
// Streams both FakeQuantize descriptions, each preceded by "_".
inline std::ostream& operator<<(std::ostream& out, const ActualValues& values) {
    out << "_" << values.fakeQuantize1;
    out << "_" << values.fakeQuantize2;
    return out;
}
// Values used to build the reference graph: the FakeQuantize description of each
// branch plus the dequantization operations placed after the Concat.
struct ResultValues {
    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
    ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
    ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
};
// Streams both FakeQuantize descriptions, each preceded by "_", followed by a
// trailing "_". Note: dequantizationAfter is not written to the stream.
inline std::ostream& operator<<(std::ostream& out, const ResultValues& values) {
    out << "_" << values.fakeQuantize1;
    out << "_" << values.fakeQuantize2;
    out << "_";
    return out;
}
// One test case: the shapes of the graph, the transformation parameters, and the
// values for the graph under test (actual) and for the reference graph (result).
struct TestValues {
    ngraph::Shape inputShape;           // shape of both input Parameters
    ngraph::Shape reshapeOutputShape;   // target shape of the Reshape in each branch
    ngraph::pass::low_precision::LayerTransformation::Params params;
    ActualValues actual;
    ResultValues result;
};
// Streams the reshape target shape, the actual values and the result values,
// each preceded by "_". inputShape and params are not written.
inline std::ostream& operator<<(std::ostream& out, const TestValues& values) {
    out << "_" << values.reshapeOutputShape;
    out << "_" << values.actual;
    out << "_" << values.result;
    return out;
}
// Test parameter: (network precision, test case values).
using ConcatTransformationParams = std::tuple<ngraph::element::Type, TestValues>;
// Parameterized fixture: builds the graph with a Reshape between each FakeQuantize
// and the Concat, runs the low precision transformations on it, and builds the
// reference graph it is expected to match.
class ConcatWithIntermediateReshapeTransformation : public LayerTransformation, public testing::WithParamInterface<ConcatTransformationParams> {
public:
    // Fills actualFunction (transformed in place) and referenceFunction.
    void SetUp() override {
        const ngraph::element::Type netPrecision = std::get<0>(GetParam());
        TestValues values = std::get<1>(GetParam());

        // Graph under test, built from the "actual" FakeQuantize values.
        actualFunction = ngraph::builder::subgraph::ConcatFunction::getOriginalWithIntermediateReshape(
            netPrecision,
            values.inputShape,
            values.reshapeOutputShape,
            values.actual.fakeQuantize1,
            values.actual.fakeQuantize2);

        // Register the Concat and Reshape transformations and apply them.
        SimpleLowPrecisionTransformer transformer;
        transformer.add<ngraph::pass::low_precision::ConcatMultiChannelsTransformation, ngraph::opset1::Concat>(values.params);
        transformer.add<ngraph::pass::low_precision::ReshapeTransformation, ngraph::opset1::Reshape>(values.params);
        transformer.transform(actualFunction);

        // Expected graph, built from the "result" values.
        referenceFunction = ngraph::builder::subgraph::ConcatFunction::getReferenceWithIntermediateReshape(
            netPrecision,
            values.inputShape,
            values.reshapeOutputShape,
            values.result.fakeQuantize1,
            values.result.fakeQuantize2,
            values.result.dequantizationAfter);
    }

    // Composes the test name from the common parameter prefix, the reshape target
    // shape, and the actual and result values, separated by "_".
    static std::string getTestCaseName(testing::TestParamInfo<ConcatTransformationParams> obj) {
        const ngraph::element::Type netPrecision = std::get<0>(obj.param);
        const TestValues values = std::get<1>(obj.param);

        std::ostringstream name;
        name << LayerTransformation::getTestCaseNameByParams(netPrecision, values.inputShape, values.params);
        name << "_" << values.reshapeOutputShape;
        name << "_" << values.actual;
        name << "_" << values.result;
        name << "_";
        return name.str();
    }
};
// The transformed graph must match the reference graph.
TEST_P(ConcatWithIntermediateReshapeTransformation, CompareFunctions) {
    // Validate the transformed graph before comparing it.
    actualFunction->validate_nodes_and_infer_types();

    // compare_functions yields (matched, message); the message is reported on failure.
    const auto comparison = compare_functions(referenceFunction, actualFunction, true, true);
    const bool functionsMatch = comparison.first;
    ASSERT_TRUE(functionsMatch) << comparison.second;
}
// Network precisions the suite is instantiated for; f16 is currently disabled.
const std::vector<ngraph::element::Type> precisions{
    ngraph::element::f32,
    // ngraph::element::f16
};
// Test cases. Fields follow the member order of TestValues:
// inputShape, reshapeOutputShape, params, actual, result.
const std::vector<TestValues> testValues = {
// U8: two FakeQuantize -> Reshape branches joined by Concat (3D input reshaped to 4D)
{
Shape{ 2, 1, 9 },  // inputShape
Shape{ 2, 1, 1, 9 },  // reshapeOutputShape
LayerTransformation::createParamsU8I8(),
// actual: FakeQuantize values for branch 1 and branch 2
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {25.5f} }
},
// result: FakeQuantize values for both branches, then dequantizationAfter
{
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {255.f} },
// NOTE(review): presumably { convert, subtract (empty), multiply with one value per branch } - confirm against DequantizationOperations
{ {ngraph::element::f32}, {}, { {0.01f, 0.1f} } }
}
},
};
// Instantiates the suite for every combination of `precisions` and `testValues`;
// test names are produced by the fixture's getTestCaseName.
INSTANTIATE_TEST_CASE_P(
smoke_LPT,
ConcatWithIntermediateReshapeTransformation,
::testing::Combine(
::testing::ValuesIn(precisions),
::testing::ValuesIn(testValues)),
ConcatWithIntermediateReshapeTransformation::getTestCaseName);
} // namespace

View File

@ -90,6 +90,13 @@ public:
const FakeQuantizeOnDataWithConstant& fqOnData2,
const FakeQuantizeOnDataWithConstant& fqOnData3);
// Builds the graph before transformation: two Parameters ("input1", "input2") of
// `precision` and `inputShape`, each followed by a FakeQuantize (fqOnData1 /
// fqOnData2) and a Reshape to `reshapeOutputShape`, joined by a Concat on axis 1
// whose output is the function result.
static std::shared_ptr<ngraph::Function> getOriginalWithIntermediateReshape(
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,
const ngraph::Shape& reshapeOutputShape,
const FakeQuantizeOnData& fqOnData1,
const FakeQuantizeOnData& fqOnData2);
static std::shared_ptr<ngraph::Function> getReference(
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,
@ -206,6 +213,14 @@ public:
const ngraph::element::Type precisionAfterOperation,
const DequantizationOperations& dequantizationOperations);
// Builds the expected graph after transformation: same topology as
// getOriginalWithIntermediateReshape, but each FakeQuantize is type-relaxed with
// its output precision set to u8, and the Concat (axis 1) is followed by the
// dequantization operations described by `dequantizationAfter`, whose output is
// the function result.
static std::shared_ptr<ngraph::Function> getReferenceWithIntermediateReshape(
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,
const ngraph::Shape& reshapeOutputShape,
const FakeQuantizeOnData& fqOnData1,
const FakeQuantizeOnData& fqOnData2,
const DequantizationOperations& dequantizationAfter);
private:
static std::shared_ptr<Node> makeMaxPool(const Output<Node>& parent, const std::vector<size_t>& kernel);
};

View File

@ -656,6 +656,43 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithReshapeAtTheEnd
return function;
}
// Builds the graph before transformation:
//
//   input1 -> FakeQuantize -> Reshape --+
//                                       +--> Concat (axis 1) -> Result
//   input2 -> FakeQuantize -> Reshape --+
//
// Both inputs have `inputShape`; both Reshapes target `reshapeOutputShape`.
std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithIntermediateReshape(
    const ngraph::element::Type precision,
    const ngraph::Shape& inputShape,
    const ngraph::Shape& reshapeOutputShape,
    const FakeQuantizeOnData& fqOnData1,
    const FakeQuantizeOnData& fqOnData2) {
    // FakeQuantize + Reshape tail shared by both branches. Nodes are created in the
    // order FakeQuantize, shape Constant, Reshape.
    const auto quantizeAndReshape = [&](
        const std::shared_ptr<ngraph::opset1::Parameter>& input,
        const FakeQuantizeOnData& fqOnData) {
        const auto fakeQuantize = makeFakeQuantize(input, precision, fqOnData);
        const auto targetShape = opset1::Constant::create(element::i64, Shape{reshapeOutputShape.size()}, reshapeOutputShape);
        return std::make_shared<opset1::Reshape>(fakeQuantize, targetShape, true);
    };

    // First branch.
    const auto firstInput = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
    firstInput->set_friendly_name("input1");
    const auto firstReshape = quantizeAndReshape(firstInput, fqOnData1);

    // Second branch: same input shape as the first.
    const auto secondInput = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
    secondInput->set_friendly_name("input2");
    const auto secondReshape = quantizeAndReshape(secondInput, fqOnData2);

    // Join the branches along axis 1 and tag the Concat in its runtime info.
    const auto concat = std::make_shared<ngraph::opset1::Concat>(
        ngraph::OutputVector{ firstReshape->output(0), secondReshape->output(0) },
        1);
    concat->set_friendly_name("output");
    concat->get_rt_info()["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");

    const ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(concat) };
    return std::make_shared<ngraph::Function>(
        results,
        ngraph::ParameterVector{ firstInput, secondInput },
        "ConcatWithIntermediateReshapeTransformation");
}
std::shared_ptr<ngraph::Function> ConcatFunction::getReference(
const ngraph::element::Type precision,
const ngraph::Shape& inputShape,
@ -1420,6 +1457,49 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithReshapeAtTheEn
return function;
}
// Builds the expected graph after transformation:
//
//   input1 -> FakeQuantize (u8 out) -> Reshape --+
//                                                +--> Concat (axis 1) -> dequantization -> Result
//   input2 -> FakeQuantize (u8 out) -> Reshape --+
//
// Both inputs have `inputShape`; both Reshapes target `reshapeOutputShape`.
std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithIntermediateReshape(
    const ngraph::element::Type precision,
    const ngraph::Shape& inputShape,
    const ngraph::Shape& reshapeOutputShape,
    const FakeQuantizeOnData& fqOnData1,
    const FakeQuantizeOnData& fqOnData2,
    const DequantizationOperations& dequantizationAfter) {
    // Type-relaxed FakeQuantize (output precision forced to u8) + Reshape tail shared
    // by both branches. Nodes are created in the order FakeQuantize, shape Constant,
    // Reshape.
    const auto quantizeAndReshape = [&](
        const std::shared_ptr<ngraph::opset1::Parameter>& input,
        const FakeQuantizeOnData& fqOnData) {
        const auto fakeQuantize = makeFakeQuantizeTypeRelaxed(input, precision, fqOnData);
        ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(fakeQuantize, element::u8);
        const auto targetShape = opset1::Constant::create(element::i64, Shape{reshapeOutputShape.size()}, reshapeOutputShape);
        return std::make_shared<opset1::Reshape>(fakeQuantize, targetShape, true);
    };

    // First branch.
    const auto firstInput = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
    firstInput->set_friendly_name("input1");
    const auto firstReshape = quantizeAndReshape(firstInput, fqOnData1);

    // Second branch: same input shape as the first.
    const auto secondInput = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
    secondInput->set_friendly_name("input2");
    const auto secondReshape = quantizeAndReshape(secondInput, fqOnData2);

    // Join the branches along axis 1 and tag the Concat in its runtime info.
    const auto concat = std::make_shared<ngraph::opset1::Concat>(
        ngraph::OutputVector{ firstReshape->output(0), secondReshape->output(0) },
        1);
    concat->set_friendly_name("output_original");
    concat->get_rt_info()["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");

    // The dequantization node takes over the "output" name and feeds the Result.
    const auto dequantization = makeDequantization(concat, dequantizationAfter);
    dequantization->set_friendly_name("output");

    const ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(dequantization) };
    return std::make_shared<ngraph::Function>(
        results,
        ngraph::ParameterVector{ firstInput, secondInput },
        "ConcatWithIntermediateReshapeTransformation");
}
std::shared_ptr<Node> ConcatFunction::makeMaxPool(const Output<Node>& parent, const std::vector<size_t>& kernel) {
const std::vector<size_t> stride = { 1, 1 };
const std::vector<size_t> padBegin = { 0, 0 };