[LPT] Concat dequantization shape fix (#4702)
This commit is contained in:
parent
dbd3a3d7a4
commit
7ca7e7d9c4
@ -312,9 +312,9 @@ void ConcatTransformation::addDequantizationLayers(
|
||||
convertNodes.push_back(dequantization.convert);
|
||||
}
|
||||
|
||||
const ngraph::element::Type precision = deqPrecision; //dequantization.data.get_element_type();
|
||||
ngraph::Shape targetShape(dequantization.data.get_shape().size(), 1ul);
|
||||
targetShape[1] = dequantization.data.get_shape()[1];
|
||||
const ngraph::element::Type precision = deqPrecision;
|
||||
ngraph::Shape targetShape(layer->get_input_shape(i).size(), 1ul);
|
||||
targetShape[1] = layer->get_input_shape(i)[1];
|
||||
|
||||
if (!allDequantizationShiftAreZero) {
|
||||
subtractNodes.push_back(dequantization.subtract == nullptr ?
|
||||
|
@ -0,0 +1,143 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "layer_transformation.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <memory>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <transformations/utils/utils.hpp>
|
||||
#include <low_precision/reshape.hpp>
|
||||
#include <low_precision/concat_multi_channels.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "lpt_ngraph_functions/concat_function.hpp"
|
||||
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
|
||||
#include "simple_low_precision_transformer.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph;
|
||||
using namespace ngraph::pass;
|
||||
|
||||
namespace {
|
||||
class ActualValues {
|
||||
public:
|
||||
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
|
||||
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const ActualValues& values) {
|
||||
return out << "_" << values.fakeQuantize1 << "_" << values.fakeQuantize2;
|
||||
}
|
||||
|
||||
class ResultValues {
|
||||
public:
|
||||
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize1;
|
||||
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize2;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const ResultValues& values) {
|
||||
return out << "_" << values.fakeQuantize1 << "_" << values.fakeQuantize2 << "_";
|
||||
}
|
||||
|
||||
class TestValues {
|
||||
public:
|
||||
ngraph::Shape inputShape;
|
||||
ngraph::Shape reshapeOutputShape;
|
||||
ngraph::pass::low_precision::LayerTransformation::Params params;
|
||||
ActualValues actual;
|
||||
ResultValues result;
|
||||
};
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& out, const TestValues& values) {
|
||||
return out << "_" << values.reshapeOutputShape << "_" << values.actual << "_" << values.result;
|
||||
}
|
||||
|
||||
typedef std::tuple <
|
||||
ngraph::element::Type,
|
||||
TestValues
|
||||
> ConcatTransformationParams;
|
||||
|
||||
class ConcatWithIntermediateReshapeTransformation : public LayerTransformation, public testing::WithParamInterface<ConcatTransformationParams> {
|
||||
public:
|
||||
void SetUp() override {
|
||||
const ngraph::element::Type precision = std::get<0>(GetParam());
|
||||
TestValues testValues = std::get<1>(GetParam());
|
||||
|
||||
actualFunction = ngraph::builder::subgraph::ConcatFunction::getOriginalWithIntermediateReshape(
|
||||
precision,
|
||||
testValues.inputShape,
|
||||
testValues.reshapeOutputShape,
|
||||
testValues.actual.fakeQuantize1,
|
||||
testValues.actual.fakeQuantize2);
|
||||
|
||||
SimpleLowPrecisionTransformer transform;
|
||||
transform.add<ngraph::pass::low_precision::ConcatMultiChannelsTransformation, ngraph::opset1::Concat>(testValues.params);
|
||||
transform.add<ngraph::pass::low_precision::ReshapeTransformation, ngraph::opset1::Reshape>(testValues.params);
|
||||
transform.transform(actualFunction);
|
||||
|
||||
referenceFunction = ngraph::builder::subgraph::ConcatFunction::getReferenceWithIntermediateReshape(
|
||||
precision,
|
||||
testValues.inputShape,
|
||||
testValues.reshapeOutputShape,
|
||||
testValues.result.fakeQuantize1,
|
||||
testValues.result.fakeQuantize2,
|
||||
testValues.result.dequantizationAfter);
|
||||
}
|
||||
|
||||
static std::string getTestCaseName(testing::TestParamInfo<ConcatTransformationParams> obj) {
|
||||
const ngraph::element::Type precision = std::get<0>(obj.param);
|
||||
const TestValues testValues = std::get<1>(obj.param);
|
||||
|
||||
std::ostringstream result;
|
||||
result <<
|
||||
LayerTransformation::getTestCaseNameByParams(precision, testValues.inputShape, testValues.params) << "_" <<
|
||||
testValues.reshapeOutputShape << "_" <<
|
||||
testValues.actual << "_" <<
|
||||
testValues.result << "_";
|
||||
return result.str();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(ConcatWithIntermediateReshapeTransformation, CompareFunctions) {
|
||||
actualFunction->validate_nodes_and_infer_types();
|
||||
auto res = compare_functions(referenceFunction, actualFunction, true, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
const std::vector<ngraph::element::Type> precisions = {
|
||||
ngraph::element::f32,
|
||||
// ngraph::element::f16
|
||||
};
|
||||
|
||||
const std::vector<TestValues> testValues = {
|
||||
// U8: Concat + MaxPool
|
||||
{
|
||||
Shape{ 2, 1, 9 },
|
||||
Shape{ 2, 1, 1, 9 },
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {2.55f} },
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {25.5f} }
|
||||
},
|
||||
{
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {2.55f}, {0.f}, {255.f} },
|
||||
{ 256ul, ngraph::Shape({}), {0.f}, {25.5f}, {0.f}, {255.f} },
|
||||
{ {ngraph::element::f32}, {}, { {0.01f, 0.1f} } }
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
smoke_LPT,
|
||||
ConcatWithIntermediateReshapeTransformation,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(precisions),
|
||||
::testing::ValuesIn(testValues)),
|
||||
ConcatWithIntermediateReshapeTransformation::getTestCaseName);
|
||||
} // namespace
|
@ -90,6 +90,13 @@ public:
|
||||
const FakeQuantizeOnDataWithConstant& fqOnData2,
|
||||
const FakeQuantizeOnDataWithConstant& fqOnData3);
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getOriginalWithIntermediateReshape(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::Shape& reshapeOutputShape,
|
||||
const FakeQuantizeOnData& fqOnData1,
|
||||
const FakeQuantizeOnData& fqOnData2);
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getReference(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
@ -206,6 +213,14 @@ public:
|
||||
const ngraph::element::Type precisionAfterOperation,
|
||||
const DequantizationOperations& dequantizationOperations);
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getReferenceWithIntermediateReshape(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::Shape& reshapeOutputShape,
|
||||
const FakeQuantizeOnData& fqOnData1,
|
||||
const FakeQuantizeOnData& fqOnData2,
|
||||
const DequantizationOperations& dequantizationAfter);
|
||||
|
||||
private:
|
||||
static std::shared_ptr<Node> makeMaxPool(const Output<Node>& parent, const std::vector<size_t>& kernel);
|
||||
};
|
||||
|
@ -656,6 +656,43 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithReshapeAtTheEnd
|
||||
return function;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> ConcatFunction::getOriginalWithIntermediateReshape(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::Shape& reshapeOutputShape,
|
||||
const FakeQuantizeOnData& fqOnData1,
|
||||
const FakeQuantizeOnData& fqOnData2) {
|
||||
const auto input1 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
|
||||
input1->set_friendly_name("input1");
|
||||
const auto fakeQuantize1 = makeFakeQuantize(input1, precision, fqOnData1);
|
||||
const auto reshape1 = std::make_shared<opset1::Reshape>(
|
||||
fakeQuantize1,
|
||||
opset1::Constant::create(element::i64, Shape{reshapeOutputShape.size()}, reshapeOutputShape),
|
||||
true);
|
||||
|
||||
const std::vector<size_t> inputShape2 = inputShape;
|
||||
const auto input2 = std::make_shared<ngraph::opset1::Parameter>(precision, ngraph::Shape(inputShape2));
|
||||
input2->set_friendly_name("input2");
|
||||
const auto fakeQuantize2 = makeFakeQuantize(input2, precision, fqOnData2);
|
||||
const auto reshape2 = std::make_shared<opset1::Reshape>(
|
||||
fakeQuantize2,
|
||||
opset1::Constant::create(element::i64, Shape{reshapeOutputShape.size()}, reshapeOutputShape),
|
||||
true);
|
||||
const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::opset1::Concat>(
|
||||
ngraph::OutputVector{ reshape1->output(0), reshape2->output(0) }, 1);
|
||||
concat->set_friendly_name("output");
|
||||
auto& rtInfo = concat->get_rt_info();
|
||||
rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
|
||||
|
||||
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(concat) };
|
||||
std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
|
||||
results,
|
||||
ngraph::ParameterVector{ input1, input2 },
|
||||
"ConcatWithIntermediateReshapeTransformation");
|
||||
|
||||
return function;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> ConcatFunction::getReference(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
@ -1420,6 +1457,49 @@ std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithReshapeAtTheEn
|
||||
return function;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> ConcatFunction::getReferenceWithIntermediateReshape(
|
||||
const ngraph::element::Type precision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::Shape& reshapeOutputShape,
|
||||
const FakeQuantizeOnData& fqOnData1,
|
||||
const FakeQuantizeOnData& fqOnData2,
|
||||
const DequantizationOperations& dequantizationAfter) {
|
||||
const auto input1 = std::make_shared<ngraph::opset1::Parameter>(precision, inputShape);
|
||||
input1->set_friendly_name("input1");
|
||||
const auto fakeQuantize1 = makeFakeQuantizeTypeRelaxed(input1, precision, fqOnData1);
|
||||
ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(fakeQuantize1, element::u8);
|
||||
const auto reshape1 = std::make_shared<opset1::Reshape>(
|
||||
fakeQuantize1,
|
||||
opset1::Constant::create(element::i64, Shape{reshapeOutputShape.size()}, reshapeOutputShape),
|
||||
true);
|
||||
|
||||
const std::vector<size_t> inputShape2 = inputShape;
|
||||
const auto input2 = std::make_shared<ngraph::opset1::Parameter>(precision, ngraph::Shape(inputShape2));
|
||||
input2->set_friendly_name("input2");
|
||||
const auto fakeQuantize2 = makeFakeQuantizeTypeRelaxed(input2, precision, fqOnData2);
|
||||
ngraph::pass::low_precision::NetworkHelper::setOutDataPrecision(fakeQuantize2, element::u8);
|
||||
const auto reshape2 = std::make_shared<opset1::Reshape>(
|
||||
fakeQuantize2,
|
||||
opset1::Constant::create(element::i64, Shape{reshapeOutputShape.size()}, reshapeOutputShape),
|
||||
true);
|
||||
const std::shared_ptr<ngraph::opset1::Concat> concat = std::make_shared<ngraph::opset1::Concat>(
|
||||
ngraph::OutputVector{ reshape1->output(0), reshape2->output(0) }, 1);
|
||||
concat->set_friendly_name("output_original");
|
||||
auto& rtInfo = concat->get_rt_info();
|
||||
rtInfo["Variant::std::string"] = std::make_shared<VariantWrapper<std::string>>("concat");
|
||||
|
||||
const auto dequantization = makeDequantization(concat, dequantizationAfter);
|
||||
dequantization->set_friendly_name("output");
|
||||
|
||||
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(dequantization) };
|
||||
std::shared_ptr<ngraph::Function> function = std::make_shared<ngraph::Function>(
|
||||
results,
|
||||
ngraph::ParameterVector{ input1, input2 },
|
||||
"ConcatWithIntermediateReshapeTransformation");
|
||||
|
||||
return function;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> ConcatFunction::makeMaxPool(const Output<Node>& parent, const std::vector<size_t>& kernel) {
|
||||
const std::vector<size_t> stride = { 1, 1 };
|
||||
const std::vector<size_t> padBegin = { 0, 0 };
|
||||
|
Loading…
Reference in New Issue
Block a user