Add gather lpt transformation (#14597)

* Add gather lpt transformation

* Add per-channel gather lpt dequantization support

* Fix review comments

* Add GPU test case

* Fix clang-format error gpu case  build error

* Fix comments

* Fix clang-format check fail

* Update docs

* Fix comments

* Add Gather opset1 quantization support
This commit is contained in:
Mang Guo 2023-02-02 10:13:52 -05:00 committed by GitHub
parent d86ba0742c
commit 9e83b081f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 952 additions and 0 deletions

View File

@ -44,6 +44,7 @@
<tab type="user" title="FakeQuantizeDecompositionTransformation" url="@ref openvino_docs_OV_UG_lpt_FakeQuantizeDecompositionTransformation"/>
<tab type="user" title="FakeQuantizeTransformation" url="@ref openvino_docs_OV_UG_lpt_FakeQuantizeTransformation"/>
<tab type="user" title="InterpolateTransformation" url="@ref openvino_docs_OV_UG_lpt_InterpolateTransformation"/>
<tab type="user" title="GatherTransformation" url="@ref openvino_docs_OV_UG_lpt_GatherTransformation"/>
<tab type="user" title="GroupConvolutionTransformation" url="@ref openvino_docs_OV_UG_lpt_GroupConvolutionTransformation"/>
<tab type="user" title="MatMulTransformation" url="@ref openvino_docs_OV_UG_lpt_MatMulTransformation"/>
<tab type="user" title="MaxPoolTransformation" url="@ref openvino_docs_OV_UG_lpt_MaxPoolTransformation"/>

View File

@ -60,6 +60,8 @@ LPT transformations propagate dequantization operations through the following op
* [Squeeze-1](@ref openvino_docs_ops_shape_Squeeze_1)
* [StridedSlice-1](@ref openvino_docs_ops_movement_StridedSlice_1)
* [Transpose-1](@ref openvino_docs_ops_movement_Transpose_1)
* [Gather-1](@ref openvino_docs_ops_movement_Gather_1)
* [Gather-7](@ref openvino_docs_ops_movement_Gather_7)
* [Gather-8](@ref openvino_docs_ops_movement_Gather_8)
* [Unsqueeze-1](@ref openvino_docs_ops_shape_Unsqueeze_1)
* [VariadicSplit-1](@ref openvino_docs_ops_movement_VariadicSplit_1)
@ -149,6 +151,7 @@ This step has the most transformations. These transformations can be separated i
* [FakeQuantizeTransformation](@ref openvino_docs_OV_UG_lpt_FakeQuantizeTransformation)
* [InterpolateTransformation](@ref openvino_docs_OV_UG_lpt_InterpolateTransformation)
* [GroupConvolutionTransformation](@ref openvino_docs_OV_UG_lpt_GroupConvolutionTransformation)
* [GatherTransformation](@ref openvino_docs_OV_UG_lpt_GatherTransformation)
* [MatMulTransformation](@ref openvino_docs_OV_UG_lpt_MatMulTransformation)
* [MaxPoolTransformation](@ref openvino_docs_OV_UG_lpt_MaxPoolTransformation)
* [MultiplyTransformation](@ref openvino_docs_OV_UG_lpt_MultiplyTransformation)

View File

@ -12,6 +12,7 @@ Main transformations are the majority of low precision transformations. Transfor
* [FakeQuantizeTransformation](@ref openvino_docs_OV_UG_lpt_FakeQuantizeTransformation)
* [InterpolateTransformation](@ref openvino_docs_OV_UG_lpt_InterpolateTransformation)
* [GroupConvolutionTransformation](@ref openvino_docs_OV_UG_lpt_GroupConvolutionTransformation)
* [GatherTransformation](@ref openvino_docs_OV_UG_lpt_GatherTransformation)
* [MatMulTransformation](@ref openvino_docs_OV_UG_lpt_MatMulTransformation)
* [MaxPoolTransformation](@ref openvino_docs_OV_UG_lpt_MaxPoolTransformation)
* [MultiplyTransformation](@ref openvino_docs_OV_UG_lpt_MultiplyTransformation)

View File

@ -0,0 +1,3 @@
# GatherTransformation transformation {#openvino_docs_OV_UG_lpt_GatherTransformation}
ngraph::pass::low_precision::GatherTransformation class represents the `Gather` operation transformation.

View File

@ -23,6 +23,7 @@ openvino_docs_OV_UG_lpt_fusemultiplytofakequantizetransformation.rst
openvino_docs_OV_UG_lpt_fusesubtracttofakequantizetransformation.rst
openvino_docs_OV_UG_lpt_groupconvolutiontransformation.rst
openvino_docs_OV_UG_lpt_interpolatetransformation.rst
openvino_docs_OV_UG_lpt_gathertransformation.rst
openvino_docs_OV_UG_lpt_linopsequencefusion.rst
openvino_docs_OV_UG_lpt_mvntransformation.rst
openvino_docs_OV_UG_lpt_markupavgpoolprecisionpreserved.rst

View File

@ -0,0 +1,25 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <algorithm>
#include "low_precision/layer_transformation.hpp"
namespace ngraph {
namespace pass {
namespace low_precision {
/**
 * @brief Low precision transformation for the Gather operation.
 *
 * Only the interface is declared here; the member functions are defined in the
 * corresponding source file.
 */
class LP_TRANSFORMATIONS_API GatherTransformation : public LayerTransformation {
public:
OPENVINO_RTTI("GatherTransformation", "0");
// Constructs the transformation; `params` defaults to a default-constructed Params.
GatherTransformation(const Params& params = Params());
// Applies the transformation to the subgraph matched by `m`; returns true if the graph was changed.
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
// Reports whether the operation keeps the precision of its input.
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const override;
// Reports whether `layer` satisfies the preconditions of this transformation.
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
};
} // namespace low_precision
} // namespace pass
} // namespace ngraph

View File

@ -0,0 +1,194 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "low_precision/gather.hpp"
#include <memory>
#include <ngraph/ngraph.hpp>
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/opsets/opset8.hpp>
#include <ngraph/pattern/op/or.hpp>
#include <ngraph/pattern/op/wrap_type.hpp>
#include "low_precision/network_helper.hpp"
#include "low_precision/rt_info/precision_preserved_attribute.hpp"
#include "itt.hpp"
namespace ngraph {
namespace pass {
namespace low_precision {
/**
 * Produces the dequantization constant to use once the dequantization operation
 * is moved from the data input of `gather` to its output.
 *
 * @param gather                  Gather node (opset1, opset7 or opset8); its input 2 is read as a Constant axis.
 * @param dequantizationConstant  Constant of the Subtract or Multiply dequantization operation.
 * @return A scalar constant when the input holds a single element. Otherwise the
 *         constant (re-created through a folded Broadcast when its rank differs
 *         from the Gather data rank), additionally gathered with the indices and
 *         axis of `gather` when its dimension along the Gather axis is not 1.
 *
 * NOTE(review): the result of each fold is converted with as_type_ptr, so the
 * returned pointer is presumably null if folding does not yield a Constant - confirm.
 */
std::shared_ptr<opset1::Constant> gatherDeqConstant(
    const std::shared_ptr<ngraph::Node> &gather,
    const std::shared_ptr<ngraph::Node> &dequantizationConstant) {
    auto constant = ov::as_type_ptr<ngraph::opset1::Constant>(dequantizationConstant);
    auto constantShape = constant->get_shape();
    if (shape_size(constantShape) == 1ul) {
        return NetworkHelper::toScalar(constant);
    }

    const auto rank = gather->get_input_partial_shape(0).size();
    if (rank != constantShape.size()) {
        // The constant may be declared without leading (e.g. batch) dimensions:
        // prepend unit dimensions until its rank matches the data rank.
        // Constants with a single dimension are intentionally left as they are.
        while ((constantShape.size() > 1) && (constantShape.size() < rank)) {
            constantShape.insert(constantShape.begin(), 1);
        }
        const auto newConstant = fold<ngraph::opset1::Broadcast>(
            constant,
            ngraph::opset1::Constant::create(ngraph::element::i32, { constantShape.size() }, constantShape));
        constant = ov::as_type_ptr<ngraph::opset1::Constant>(newConstant);
    }

    const int64_t axis = ov::as_type_ptr<opset1::Constant>(gather->get_input_node_shared_ptr(2))->cast_vector<int64_t>()[0];
    const size_t normalizedAxis = normalize_axis(gather->get_friendly_name(), axis, gather->get_input_partial_shape(0).rank());

    // 1D multi-element constants are not padded above, so the axis may lie outside
    // of constantShape. Indexing it would be undefined behaviour; in that case the
    // constant is returned unchanged, as is done when the axis dimension equals 1.
    if ((normalizedAxis >= constantShape.size()) || (constantShape[normalizedAxis] == 1ul)) {
        return constant;
    }

    // The dequantization channel coincides with the Gather axis: apply the very same
    // Gather (same indices, axis and batch_dims) to the constant itself.
    if (const auto gather1 = ov::as_type_ptr<ngraph::opset1::Gather>(gather)) {
        const auto output = fold<ngraph::opset1::Gather>(
            constant,
            gather1->input_value(1),
            gather1->input_value(2));
        constant = ov::as_type_ptr<opset1::Constant>(NetworkHelper::toScalarIfPossible(output));
    } else if (const auto gather7 = ov::as_type_ptr<ngraph::opset7::Gather>(gather)) {
        const auto output = fold<ngraph::opset7::Gather>(
            constant,
            gather7->input_value(1),
            gather7->input_value(2),
            gather7->get_batch_dims());
        constant = ov::as_type_ptr<opset1::Constant>(NetworkHelper::toScalarIfPossible(output));
    } else if (const auto gather8 = ov::as_type_ptr<ngraph::opset8::Gather>(gather)) {
        const auto output = fold<ngraph::opset8::Gather>(
            constant,
            gather8->input_value(1),
            gather8->input_value(2),
            gather8->get_batch_dims());
        constant = ov::as_type_ptr<opset1::Constant>(NetworkHelper::toScalarIfPossible(output));
    }
    return constant;
}
// Registers a matcher for Gather (opset1, opset7, opset8) whose first input is produced
// by a Multiply; the two remaining inputs may be produced by anything.
GatherTransformation::GatherTransformation(const Params& params) : LayerTransformation(params) {
    MATCHER_SCOPE(GatherTransformation);

    const auto dataInput = pattern::wrap_type<opset1::Multiply>();
    const auto indicesInput = pattern::any_input();
    const auto axisInput = pattern::any_input();
    auto gatherPattern = pattern::wrap_type<opset1::Gather, opset7::Gather, opset8::Gather>(
        { dataInput, indicesInput, axisInput });

    // The user supplied transformation_callback may veto the transformation of a matched node.
    ngraph::graph_rewrite_callback onMatch = [this](pattern::Matcher& matcher) {
        const auto root = matcher.get_match_root();
        return transformation_callback(root) ? false : transform(*context, matcher);
    };

    auto gatherMatcher = std::make_shared<ngraph::pattern::Matcher>(gatherPattern, matcher_name);
    register_matcher(gatherMatcher, onMatch);
}
/**
 * Moves the dequantization operations from the data input of the matched Gather
 * to its output.
 *
 * The Multiply and Subtract constants are first replaced with the constants
 * computed by gatherDeqConstant, then the (re-read) dequantization is moved
 * after the Gather.
 *
 * @return false if canBeTransformed rejects the matched node, true otherwise.
 */
bool GatherTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) {
    // Fetch the match root once and reuse it below.
    const auto root = m.get_match_root();
    if (!canBeTransformed(context, root)) {
        return false;
    }

    const std::shared_ptr<Node> gather = NetworkHelper::separateInStandaloneBranch(root, defaultPrecisions);
    FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(gather, defaultPrecisions);

    if (dequantization.multiply != nullptr) {
        const auto newConstant = gatherDeqConstant(gather, dequantization.multiplyConstant);
        replace_node(dequantization.multiplyConstant, newConstant);
    }
    if (dequantization.subtract != nullptr) {
        const auto newConstant = gatherDeqConstant(gather, dequantization.subtractConstant);
        replace_node(dequantization.subtractConstant, newConstant);
    }

    // The constants were replaced above, so the dequantization is read again
    // from the graph instead of reusing the `dequantization` snapshot.
    moveDequantizationAfter(context, gather, NetworkHelper::getDequantization(gather, defaultPrecisions), false);
    return true;
}
/**
 * Checks whether the dequantization on the data input of `operation` can be moved
 * to its output.
 *
 * Accepted cases:
 *  - all present dequantization constants (Multiply, Subtract) are scalar-like;
 *  - otherwise the Gather axis is a Constant, the data rank is static, and for
 *    every constant whose dimension along the Gather axis is not 1 the indices
 *    are a Constant of rank 0 or 1, so the constant itself can be gathered.
 *
 * @return true if the transformation may be applied, false otherwise.
 */
bool GatherTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
    if (!LayerTransformation::canBeTransformed(context, operation)) {
        return false;
    }

    auto dequantization = NetworkHelper::getDequantization(operation, defaultPrecisions);
    if (dequantization.empty()) {
        return false;
    }

    const bool isScalar =
        ((dequantization.multiply == nullptr) || NetworkHelper::isScalarLike(dequantization.multiplyConstant)) &&
        ((dequantization.subtract == nullptr) || NetworkHelper::isScalarLike(dequantization.subtractConstant));
    if (isScalar) {
        return true;
    }

    // If dequantization constant is not scalar, Gather axis must be constant.
    // If the Gather axis matches with dequantization channel, the Gather indices
    // must be constant and have 0D or 1D shape so we can do folding.
    const auto axisConstant = ov::as_type_ptr<opset1::Constant>(operation->get_input_node_shared_ptr(2));
    if (axisConstant == nullptr) {
        return false;
    }

    if (operation->get_input_partial_shape(0).rank().is_dynamic()) {
        return false;
    }

    // These values do not depend on the checked constant, so compute them once.
    const auto rank = operation->get_input_partial_shape(0).size();
    const int64_t axis = axisConstant->cast_vector<int64_t>()[0];
    const size_t normalizedAxis = normalize_axis(operation->get_friendly_name(), axis, operation->get_input_partial_shape(0).rank());

    const auto canBeFolded = [&](const std::shared_ptr<ngraph::Node>& dequantizationConstant) {
        auto constantShape = dequantizationConstant->get_shape();
        // Pad constants declared without leading dimensions; 1D constants are left as they are.
        while ((constantShape.size() > 1) && (constantShape.size() < rank)) {
            constantShape.insert(constantShape.begin(), 1);
        }

        // A 1D multi-element constant is not padded, so the axis may lie outside of
        // constantShape. Reject the case instead of indexing out of bounds.
        if (normalizedAxis >= constantShape.size()) {
            return false;
        }
        if (constantShape[normalizedAxis] == 1ul) {
            return true;
        }

        const auto indicesConstant = ov::as_type_ptr<opset1::Constant>(operation->get_input_node_shared_ptr(1));
        if (indicesConstant == nullptr) {
            return false;
        }
        const auto indicesRank = indicesConstant->get_shape().size();
        return (indicesRank == 0) || (indicesRank == 1);
    };

    if ((dequantization.multiply && !canBeFolded(dequantization.multiplyConstant)) ||
        (dequantization.subtract && !canBeFolded(dequantization.subtractConstant))) {
        return false;
    }
    return true;
}
// Always reports the layer as precision preserved; the `layer` argument is not inspected.
bool GatherTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const {
return true;
}
} // namespace low_precision
} // namespace pass
} // namespace ngraph

View File

@ -69,6 +69,7 @@
#include "low_precision/shuffle_channels.hpp"
#include "low_precision/strided_slice.hpp"
#include "low_precision/transpose.hpp"
#include "low_precision/gather.hpp"
#include "low_precision/unsqueeze.hpp"
#include "low_precision/variadic_split.hpp"
#include "low_precision/move_fake_quantize.hpp"
@ -246,6 +247,7 @@ bool ngraph::pass::low_precision::LowPrecision::run_on_model(const std::shared_p
ADD_MATCHER(common, SplitTransformation, params)
ADD_MATCHER(common, StridedSliceTransformation, params)
ADD_MATCHER(common, TransposeTransformation, params)
ADD_MATCHER(common, GatherTransformation, params)
ADD_MATCHER(common, UnsqueezeTransformation, params)
ADD_MATCHER(common, VariadicSplitTransformation, params)

View File

@ -0,0 +1,275 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <gtest/gtest.h>
#include <low_precision/gather.hpp>
#include <memory>
#include <sstream>
#include <string>
#include <transformations/init_node_info.hpp>
#include <transformations/utils/utils.hpp>
#include "common_test_utils/ngraph_test_utils.hpp"
#include "layer_transformation.hpp"
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
#include "lpt_ngraph_functions/gather_function.hpp"
#include "simple_low_precision_transformer.hpp"
namespace {
using namespace testing;
using namespace ngraph::pass;
using namespace ngraph;
// One test case of the GatherTransformation unit test: Gather attributes, LPT parameters
// and the dequantization operations of the original and of the reference subgraph.
class GatherTransformationTestValues {
public:
// Subgraph before the transformation.
class Actual {
public:
ngraph::element::Type precisionBeforeDequantization;
ngraph::builder::subgraph::DequantizationOperations dequantization;
};
// Subgraph expected after the transformation.
class Expected {
public:
ngraph::element::Type precisionBeforeDequantization;
// Dequantization placed before the Gather.
ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
ngraph::element::Type precisionAfterOperation;
// Dequantization placed after the Gather.
ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
};
// Shape and values of the constant Gather indices.
std::vector<size_t> gatherIndicesShape;
std::vector<int> gatherIndicesValues;
// Values of the constant Gather axis input.
std::vector<int> axis;
// Passed to the Gather constructor for opset 7 and 8 only.
int64_t batch_dims;
TestTransformationParams params;
Actual actual;
Expected expected;
};
// Test parameters: input shape, test case values, Gather opset version.
using GatherTransformationParams = std::tuple<ngraph::PartialShape, GatherTransformationTestValues, int>;
// Parameterized fixture: builds the original Gather subgraph, runs GatherTransformation
// on it and prepares the reference subgraph the result is compared with.
class GatherTransformation : public LayerTransformation,
                             public testing::WithParamInterface<GatherTransformationParams> {
public:
    void SetUp() override {
        const GatherTransformationParams& param = GetParam();
        const ngraph::PartialShape dataShape = std::get<0>(param);
        const GatherTransformationTestValues values = std::get<1>(param);
        const int opsetVersion = std::get<2>(param);

        actualFunction = ngraph::builder::subgraph::GatherFunction::getOriginal(
            dataShape,
            values.gatherIndicesShape,
            values.gatherIndicesValues,
            values.axis,
            values.batch_dims,
            values.actual.precisionBeforeDequantization,
            values.actual.dequantization,
            opsetVersion);

        SimpleLowPrecisionTransformer lptRunner;
        lptRunner.add<ngraph::pass::low_precision::GatherTransformation, ngraph::opset1::Gather>(values.params);
        lptRunner.transform(actualFunction);

        referenceFunction = ngraph::builder::subgraph::GatherFunction::getReference(
            dataShape,
            values.gatherIndicesShape,
            values.gatherIndicesValues,
            values.axis,
            values.batch_dims,
            values.expected.precisionBeforeDequantization,
            values.expected.dequantizationBefore,
            values.expected.precisionAfterOperation,
            values.expected.dequantizationAfter,
            opsetVersion);
    }

    // Composes the test name from the input shape, the Gather attributes,
    // the dequantization operations and the opset version.
    static std::string getTestCaseName(testing::TestParamInfo<GatherTransformationParams> obj) {
        const GatherTransformationTestValues values = std::get<1>(obj.param);

        std::ostringstream name;
        name << "_" << std::get<0>(obj.param);
        name << "_" << values.gatherIndicesShape;
        name << "_" << values.gatherIndicesValues;
        name << "_" << values.axis;
        name << "_" << values.batch_dims;
        name << "_" << values.actual.precisionBeforeDequantization;
        name << "_" << values.actual.dequantization;
        name << "_" << values.expected.dequantizationBefore;
        name << "_" << std::get<2>(obj.param);
        return name.str();
    }
};
// The transformed function must match the reference function and keep all node names unique.
TEST_P(GatherTransformation, CompareFunctions) {
    ov::pass::InitNodeInfo().run_on_model(actualFunction);
    actualFunction->validate_nodes_and_infer_types();

    const auto comparison = compare_functions(actualFunction, referenceFunction, true, true);
    ASSERT_TRUE(comparison.first) << comparison.second;

    const bool uniqueNames = LayerTransformation::allNamesAreUnique(actualFunction);
    ASSERT_TRUE(uniqueNames) << "Not all names are unique";
}
// Test cases instantiated for Gather of opset 1, 7 and 8, on a static and on a fully dynamic 3D input.
// Each entry lists, in order: gatherIndicesShape, gatherIndicesValues, axis, batch_dims,
// params, actual {precision, dequantization},
// expected {precision, dequantizationBefore, precisionAfterOperation, dequantizationAfter}.
namespace testValues1 {
const std::vector<int> opset_version = {1, 7, 8};
const std::vector<ngraph::PartialShape> inputShapes3D = {{3, 3, 4}, {-1, -1, -1}};
const std::vector<GatherTransformationTestValues> testValues = {
// U8: per-tensor quantization, fully specified Subtract description
{{1},
{0},
{0},
std::int64_t{0},
LayerTransformation::createParamsU8I8(),
{ngraph::element::u8,
{{ngraph::element::f32}, {{128}, ngraph::element::f32, {}, true, 1, ngraph::element::u8, true}, {0.1f}}},
{ngraph::element::u8,
{{}, {}, {}},
ngraph::element::u8,
{{ngraph::element::f32}, {{128}, ngraph::element::f32, {}, true, 1, ngraph::element::u8, true}, {0.1f}}}},
// U8: per-tensor quantization, 1D indices with two values
{{2},
{0, 1},
{0},
std::int64_t{0},
LayerTransformation::createParamsU8I8(),
{ngraph::element::u8, {{ngraph::element::f32}, {128}, {0.1f}}},
{ngraph::element::u8, {{}, {}, {}}, ngraph::element::u8, {{ngraph::element::f32}, {128}, {0.1f}}}},
// U8: per-tensor quantization, 2D indices, axis = 1, batch_dims = 1
{{3, 2},
{1, 2, 1, 2, 1, 2},
{1},
std::int64_t{1},
LayerTransformation::createParamsU8I8(),
{ngraph::element::u8, {{ngraph::element::f32}, {128}, {0.1f}}},
{ngraph::element::u8, {{}, {}, {}}, ngraph::element::u8, {{ngraph::element::f32}, {128}, {0.1f}}}},
// U8: per-channel quantization with the same values
{{1},
{0},
{0},
std::int64_t{0},
LayerTransformation::createParamsU8I8(),
{ngraph::element::u8,
{{ngraph::element::f32},
{{128.f}, element::undefined, {1, 3, 1}, false, 1ul, element::u8, true},
{{0.1}, ngraph::element::f32, {1, 3, 1}}}},
{ngraph::element::u8,
{{}, {}, {}},
ngraph::element::u8,
{{ngraph::element::f32},
{{128.f}, element::undefined, {1, 3, 1}, false, 1ul, element::u8, true},
{{0.1}, ngraph::element::f32, {1, 3, 1}}}}},
// U8: per-channel quantization, gather axis matches the channel
{{1},
{0},
{1}, // axis
std::int64_t{0},
LayerTransformation::createParamsU8I8(),
{ngraph::element::u8,
{{ngraph::element::f32},
{{128, 64, 32}, ngraph::element::f32, {1, 3, 1}},
{{0.3f, 0.2f, 0.1f}, ngraph::element::f32, {1, 3, 1}}}},
{ngraph::element::u8,
{{}, {}, {}},
ngraph::element::u8,
{{ngraph::element::f32}, {{128}, ngraph::element::f32, {}}, {{0.3f}, ngraph::element::f32, {}}}}},
// U8: per-channel quantization, gather axis matches the channel, the rank of the
// quantization constant is less than the rank of the input
{{1},
{1},
{1}, // axis
std::int64_t{0},
LayerTransformation::createParamsU8I8(),
{ngraph::element::u8,
{{ngraph::element::f32},
{{128, 64, 32}, ngraph::element::f32, {3, 1}},
{{0.3f, 0.2f, 0.1f}, ngraph::element::f32, {3, 1}}}},
{ngraph::element::u8,
{{}, {}, {}},
ngraph::element::u8,
{{ngraph::element::f32}, {{64}, ngraph::element::f32, {}}, {{0.2f}, ngraph::element::f32, {}}}}},
// U8: per-channel quantization, gather axis and channel don't match
{{1},
{0},
{0},
std::int64_t{0},
LayerTransformation::createParamsU8I8(),
{ngraph::element::u8,
{{ngraph::element::f32},
{{128, 64, 32}, ngraph::element::f32, {1, 3, 1}},
{{0.3f, 0.2f, 0.1f}, ngraph::element::f32, {1, 3, 1}}}},
{ngraph::element::u8,
{{}, {}, {}},
ngraph::element::u8,
{{ngraph::element::f32},
{{128, 64, 32}, ngraph::element::f32, {1, 3, 1}},
{{0.3f, 0.2f, 0.1f}, ngraph::element::f32, {1, 3, 1}}}}},
// U8: per-channel quantization, negative axis, gather axis matches the channel
{{1},
{0},
{-2}, // axis
std::int64_t{0},
LayerTransformation::createParamsU8I8(),
{ngraph::element::u8,
{{ngraph::element::f32},
{{128, 64, 32}, ngraph::element::f32, {1, 3, 1}},
{{0.3f, 0.2f, 0.1f}, ngraph::element::f32, {1, 3, 1}}}},
{ngraph::element::u8,
{{}, {}, {}},
ngraph::element::u8,
{{ngraph::element::f32}, {{128}, ngraph::element::f32, {}}, {{0.3f}, ngraph::element::f32, {}}}}},
// empty: no dequantization operations before the Gather
{{1},
{0},
{0},
std::int64_t{0},
LayerTransformation::createParamsU8I8(),
{ngraph::element::u8, {}},
{ngraph::element::u8, {}, ngraph::element::u8, {}}},
};
INSTANTIATE_TEST_SUITE_P(smoke_LPT,
GatherTransformation,
::testing::Combine(::testing::ValuesIn(inputShapes3D),
::testing::ValuesIn(testValues),
::testing::ValuesIn(opset_version)),
GatherTransformation::getTestCaseName);
} // namespace testValues1
// Test cases with negative indices values, instantiated for Gather of opset 8 only.
// Each entry lists, in order: gatherIndicesShape, gatherIndicesValues, axis, batch_dims,
// params, actual {precision, dequantization},
// expected {precision, dequantizationBefore, precisionAfterOperation, dequantizationAfter}.
namespace testValues2 {
const std::vector<int> opset_version = {8};
const std::vector<ngraph::PartialShape> inputShapes3D = {{3, 3, 4}, {-1, -1, -1}};
const std::vector<GatherTransformationTestValues> testValues = {
// U8: per-tensor quantization, negative indices value
{{3, 2},
{-2, 2, -2, 2, -2, 2}, // indices value
{1},
std::int64_t{1},
LayerTransformation::createParamsU8I8(),
{ngraph::element::u8, {{ngraph::element::f32}, {128}, {0.1f}}},
{ngraph::element::u8, {{}, {}, {}}, ngraph::element::u8, {{ngraph::element::f32}, {128}, {0.1f}}}},
// U8: per-channel quantization, negative indices value, gather axis matches the channel
{{1},
{-1}, // indices value
{1}, // axis
std::int64_t{0},
LayerTransformation::createParamsU8I8(),
{ngraph::element::u8,
{{ngraph::element::f32},
{{128, 64, 32}, ngraph::element::f32, {1, 3, 1}},
{{0.3f, 0.2f, 0.1f}, ngraph::element::f32, {1, 3, 1}}}},
{ngraph::element::u8,
{{}, {}, {}},
ngraph::element::u8,
{{ngraph::element::f32}, {{32}, ngraph::element::f32, {}}, {{0.1f}, ngraph::element::f32, {}}}}},
};
INSTANTIATE_TEST_SUITE_P(smoke_LPT,
GatherTransformation,
::testing::Combine(::testing::ValuesIn(inputShapes3D),
::testing::ValuesIn(testValues),
::testing::ValuesIn(opset_version)),
GatherTransformation::getTestCaseName);
} // namespace testValues2
} // namespace

View File

@ -0,0 +1,89 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include "low_precision_transformations/gather_transformation.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
// Instantiation of the shared GatherTransformation test for the CPU device.
namespace {
// Network precisions the test is instantiated for.
const std::vector<ngraph::element::Type> precisions = {
ngraph::element::f32,
};
// Gather opset versions the test is instantiated for.
const std::vector<int> opset_version = {
1, 7, 8
};
// Each entry lists, in order: inputShape, gatherIndicesShape, gatherIndicesValues, axis,
// batch_dims, params, precisionBeforeFq, fqOnData.
// NOTE(review): the fqOnData initializers presumably are {levels, constant shape, input low,
// input high, output low, output high} - confirm against FakeQuantizeOnData.
const std::vector<GatherTransformationTestValues> testValues = {
// U8: per-tensor quantization
{
{3, 3, 4},
{1},
{0},
{0},
std::int64_t{0},
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(),
ngraph::element::f32,
{256, {}, {0.f}, {25.5f}, {12.5f}, {25.5f + 12.5f}}
},
// U8: per-channel quantization
{
{1, 3, 5},
{1},
{0},
{0},
std::int64_t{0},
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(),
ngraph::element::f32,
{
256,
{1, 3, 1},
{0.f, 0.f, 0.f},
{25.5f, 25.5f, 25.5f},
{0.f, 12.5f, 25.5f},
{25.5f, 25.5f + 12.5f * 2, 25.5f + 12.5f * 4}
}
},
// U8: per-channel quantization, axis matches the dequantization channel, the rank of the
// dequantization constant is less than the rank of the gather input
{
{1, 3, 4},
{1},
{0},
{1},
std::int64_t{0},
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(),
ngraph::element::f32,
{
256,
{3, 1},
{0.f, 0.f, 0.f},
{25.5f, 25.5f, 25.5f},
{0.f, 12.5f, 25.5f},
{25.5f, 25.5f + 12.5f * 2, 25.5f + 12.5f * 4}
}
},
// 4D input, per-tensor quantization, two indices
{
{3, 4, 100, 2},
{2},
{1, 2},
{0},
std::int64_t{0},
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(),
ngraph::element::f32,
{256, {}, {0.f}, {25.5f}, {12.5f}, {25.5f + 12.5f}}
},
};
INSTANTIATE_TEST_SUITE_P(smoke_LPT, GatherTransformation,
::testing::Combine(
::testing::ValuesIn(precisions),
::testing::Values(CommonTestUtils::DEVICE_CPU),
::testing::ValuesIn(testValues),
::testing::ValuesIn(opset_version)),
GatherTransformation::getTestCaseName);
} // namespace

View File

@ -0,0 +1,72 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include "low_precision_transformations/gather_transformation.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
// Instantiation of the shared GatherTransformation test for the GPU device.
namespace {
// Network precisions the test is instantiated for.
const std::vector<ngraph::element::Type> precisions = {
ngraph::element::f32,
ngraph::element::f16
};
// Gather opset versions the test is instantiated for.
const std::vector<int> opset_version = {
1, 7, 8
};
// Each entry lists, in order: inputShape, gatherIndicesShape, gatherIndicesValues, axis,
// batch_dims, params, precisionBeforeFq, fqOnData.
// NOTE(review): the fqOnData initializers presumably are {levels, constant shape, input low,
// input high, output low, output high} - confirm against FakeQuantizeOnData.
const std::vector<GatherTransformationTestValues> testValues = {
// U8: per-tensor quantization
{
{3, 3, 4},
{1},
{0},
{0},
std::int64_t{0},
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(),
ngraph::element::f32,
{256, {}, {0.f}, {25.5f}, {12.5f}, {25.5f + 12.5f}}
},
// U8: per-channel quantization
{
{1, 3, 5},
{1},
{0},
{0},
std::int64_t{0},
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(),
ngraph::element::f32,
{
256,
{1, 3, 1},
{0.f, 0.f, 0.f},
{25.5f, 25.5f, 25.5f},
{0.f, 12.5f, 25.5f},
{25.5f, 25.5f + 12.5f * 2, 25.5f + 12.5f * 4}
}
},
// 4D input, per-tensor quantization, two indices
{
{3, 4, 100, 2},
{2},
{1, 2},
{0},
std::int64_t{0},
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(),
ngraph::element::f32,
{256, {}, {0.f}, {25.5f}, {12.5f}, {25.5f + 12.5f}}
},
};
INSTANTIATE_TEST_SUITE_P(smoke_LPT, GatherTransformation,
::testing::Combine(
::testing::ValuesIn(precisions),
::testing::Values(CommonTestUtils::DEVICE_GPU),
::testing::ValuesIn(testValues),
::testing::ValuesIn(opset_version)),
GatherTransformation::getTestCaseName);
} // namespace

View File

@ -0,0 +1,43 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <string>
#include <memory>
#include "shared_test_classes/base/low_precision_transformations/layer_transformation.hpp"
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
namespace LayerTestsDefinitions {
// One test case of the shared GatherTransformation test.
class GatherTransformationTestValues {
public:
// Shape of the network input.
ngraph::PartialShape inputShape;
// Shape and values of the constant Gather indices.
std::vector<size_t> gatherIndicesShape;
std::vector<int> gatherIndicesValues;
// Values of the constant Gather axis input.
std::vector<int> axis;
int64_t batch_dims;
ngraph::pass::low_precision::LayerTransformation::Params params;
// Precision of the network input that feeds the FakeQuantize.
ngraph::element::Type precisionBeforeFq;
// Description of the FakeQuantize placed on the Gather data input.
ngraph::builder::subgraph::FakeQuantizeOnData fqOnData;
};
// Test parameters: network precision, target device, test case values, Gather opset version.
using GatherTransformationParams = std::tuple<
    ngraph::element::Type,
    std::string,
    GatherTransformationTestValues,
    int>;
// Shared (device independent) parameterized test of the Gather low precision transformation.
class GatherTransformation :
public testing::WithParamInterface<GatherTransformationParams>,
public LayerTestsUtils::LayerTransformation {
public:
// Builds the test name from the test parameters.
static std::string getTestCaseName(const testing::TestParamInfo<GatherTransformationParams>& obj);
protected:
// Builds the tested function from the test parameters.
void SetUp() override;
};
} // namespace LayerTestsDefinitions

View File

@ -0,0 +1,56 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "low_precision_transformations/gather_transformation.hpp"
#include <memory>
#include <tuple>
#include <vector>
#include <string>
#include <ie_core.hpp>
#include <transformations/init_node_info.hpp>
#include "lpt_ngraph_functions/gather_function.hpp"
namespace LayerTestsDefinitions {
// Test name layout: <precision>_<device>_<input shape>_<opset version>.
std::string GatherTransformation::getTestCaseName(const testing::TestParamInfo<GatherTransformationParams>& obj) {
    const ngraph::element::Type& netPrecision = std::get<0>(obj.param);
    const std::string& device = std::get<1>(obj.param);
    const GatherTransformationTestValues& values = std::get<2>(obj.param);
    const int opsetVersion = std::get<3>(obj.param);

    std::ostringstream name;
    name << netPrecision << "_";
    name << device << "_";
    name << values.inputShape << "_";
    name << opsetVersion;
    return name.str();
}
void GatherTransformation::SetUp() {
ngraph::element::Type precision;
GatherTransformationTestValues testValues;
int opset_version;
std::tie(precision, targetDevice, testValues, opset_version) = this->GetParam();
function = ngraph::builder::subgraph::GatherFunction::getOriginal(
testValues.inputShape,
testValues.gatherIndicesShape,
testValues.gatherIndicesValues,
testValues.axis,
testValues.batch_dims,
testValues.precisionBeforeFq,
testValues.fqOnData,
opset_version);
}
// Runs the function built in SetUp through the base class Run().
TEST_P(GatherTransformation, CompareWithRefImpl) {
    Run();
}
} // namespace LayerTestsDefinitions

View File

@ -0,0 +1,54 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <vector>
#include <ngraph/ngraph.hpp>
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
namespace ngraph {
namespace builder {
namespace subgraph {
// Builders of the Gather subgraphs used by the low precision transformation tests.
class GatherFunction {
public:
// Function with the dequantization operations placed on the Gather data input.
static std::shared_ptr<ngraph::Function> getOriginal(
const ngraph::PartialShape& inputShape,
const std::vector<size_t>& gatherIndicesShape,
const std::vector<int>& gatherIndicesValues,
const std::vector<int>& axis,
const int64_t batch_dims,
const ngraph::element::Type precisionBeforeDequantization,
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
const int opset_version);
// Function with a FakeQuantize described by `fqOnData` on the Gather data input.
static std::shared_ptr<ngraph::Function> getOriginal(
const ngraph::PartialShape& inputShape,
const std::vector<size_t>& gatherIndicesShape,
const std::vector<int>& gatherIndicesValues,
const std::vector<int>& axis,
const int64_t batch_dims,
const ngraph::element::Type precisionBeforeFq,
const FakeQuantizeOnData& fqOnData,
const int opset_version);
// Reference function with dequantization operations before and after the Gather.
static std::shared_ptr<ngraph::Function> getReference(
const ngraph::PartialShape& inputShape,
const std::vector<size_t>& gatherIndicesShape,
const std::vector<int>& gatherIndicesValues,
const std::vector<int>& axis,
const int64_t batch_dims,
const ngraph::element::Type precisionBeforeDequantization,
const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
const ngraph::element::Type precisionAfterOperation,
const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter,
const int opset_version);
};
} // namespace subgraph
} // namespace builder
} // namespace ngraph

View File

@ -0,0 +1,133 @@
// Copyright (C) 2018-2022 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "lpt_ngraph_functions/gather_function.hpp"
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset7.hpp>
#include <ngraph/opsets/opset8.hpp>
#include "lpt_ngraph_functions/common/builders.hpp"
namespace ngraph {
namespace builder {
namespace subgraph {
/**
 * Builds Parameter -> dequantization -> Gather -> Result.
 *
 * @param inputShape                     shape of the Parameter
 * @param gatherIndicesShape             shape of the constant indices input
 * @param gatherIndicesValues            values of the constant indices input
 * @param axis                           values of the constant axis input
 * @param batch_dims                     Gather batch_dims attribute (used for opset 7 and 8 only)
 * @param precisionBeforeDequantization  element type of the Parameter
 * @param dequantization                 dequantization operations placed before the Gather
 * @param opset_version                  Gather opset version: 1, 7 or 8
 * @throws std::runtime_error for any other opset version
 */
std::shared_ptr<ngraph::Function> GatherFunction::getOriginal(
    const ngraph::PartialShape& inputShape,
    const std::vector<size_t>& gatherIndicesShape,
    const std::vector<int>& gatherIndicesValues,
    const std::vector<int>& axis,
    const int64_t batch_dims,
    const ngraph::element::Type precisionBeforeDequantization,
    const ngraph::builder::subgraph::DequantizationOperations& dequantization,
    const int opset_version) {
    const auto input = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization, inputShape);
    const std::shared_ptr<Node> dequantizationOp = makeDequantization(input, dequantization);

    const auto indicesNode = std::make_shared<ngraph::opset1::Constant>(
        ngraph::element::i64,
        ngraph::Shape(gatherIndicesShape),
        gatherIndicesValues);
    // opset1::Constant is used here as in the other builders of this file.
    const auto axisNode = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{ axis.size() }, axis);

    std::shared_ptr<Node> gather;
    switch (opset_version) {
    case 1:
        // Gather-1 has no batch_dims attribute.
        gather = std::make_shared<ngraph::opset1::Gather>(dequantizationOp, indicesNode, axisNode);
        break;
    case 7:
        gather = std::make_shared<ngraph::opset7::Gather>(dequantizationOp, indicesNode, axisNode, batch_dims);
        break;
    case 8:
        gather = std::make_shared<ngraph::opset8::Gather>(dequantizationOp, indicesNode, axisNode, batch_dims);
        break;
    default:
        throw std::runtime_error("Unknown opset version");
    }
    gather->set_friendly_name("output");

    ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(gather) };
    return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "GatherFunction");
}
// Builds Parameter -> [FakeQuantize] -> Gather -> Result. The FakeQuantize is omitted when
// `fqOnData` is empty. Throws std::runtime_error for an opset version other than 1, 7 or 8.
std::shared_ptr<ngraph::Function> GatherFunction::getOriginal(
    const ngraph::PartialShape& inputShape,
    const std::vector<size_t>& gatherIndicesShape,
    const std::vector<int>& gatherIndicesValues,
    const std::vector<int>& axis,
    const int64_t batch_dims,
    const ngraph::element::Type precisionBeforeFq,
    const FakeQuantizeOnData& fqOnData,
    const int opset_version) {
    const auto input = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeFq, inputShape);

    // Data input of the Gather: the parameter itself or a FakeQuantize on top of it.
    std::shared_ptr<Node> gatherData = input;
    if (!fqOnData.empty()) {
        gatherData = makeFakeQuantize(input, precisionBeforeFq, fqOnData);
    }

    const auto indices = std::make_shared<ngraph::opset1::Constant>(
        ngraph::element::i64,
        ngraph::Shape(gatherIndicesShape),
        gatherIndicesValues);
    const auto axisConstant = std::make_shared<ngraph::opset1::Constant>(
        ngraph::element::i64,
        ngraph::Shape{ axis.size() },
        axis);

    std::shared_ptr<Node> gather;
    switch (opset_version) {
    case 1:
        // Gather-1 has no batch_dims attribute.
        gather = std::make_shared<ngraph::opset1::Gather>(gatherData, indices, axisConstant);
        break;
    case 7:
        gather = std::make_shared<ngraph::opset7::Gather>(gatherData, indices, axisConstant, batch_dims);
        break;
    case 8:
        gather = std::make_shared<ngraph::opset8::Gather>(gatherData, indices, axisConstant, batch_dims);
        break;
    default:
        throw std::runtime_error("Unknown opset version");
    }

    ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(gather) };
    return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "GatherFunction");
}
// Builds the reference function: Parameter -> dequantizationBefore -> Gather -> dequantizationAfter -> Result.
// Raises an LPT exception when the element type before or after the Gather differs from
// `precisionAfterOperation`, and std::runtime_error for an opset version other than 1, 7 or 8.
std::shared_ptr<ngraph::Function> GatherFunction::getReference(
    const ngraph::PartialShape& inputShape,
    const std::vector<size_t>& gatherIndicesShape,
    const std::vector<int>& gatherIndicesValues,
    const std::vector<int>& axis,
    const int64_t batch_dims,
    const ngraph::element::Type precisionBeforeDequantization,
    const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
    const ngraph::element::Type precisionAfterOperation,
    const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter,
    const int opset_version) {
    const auto input = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization, inputShape);
    const std::shared_ptr<Node> gatherData = makeDequantization(input, dequantizationBefore);

    const auto indices = std::make_shared<ngraph::opset1::Constant>(
        ngraph::element::i64,
        ngraph::Shape(gatherIndicesShape),
        gatherIndicesValues);
    const auto axisConstant = std::make_shared<ngraph::opset1::Constant>(
        ngraph::element::i64,
        ngraph::Shape{ axis.size() },
        axis);

    std::shared_ptr<Node> gather;
    switch (opset_version) {
    case 1:
        // Gather-1 has no batch_dims attribute.
        gather = std::make_shared<ngraph::opset1::Gather>(gatherData, indices, axisConstant);
        break;
    case 7:
        gather = std::make_shared<ngraph::opset7::Gather>(gatherData, indices, axisConstant, batch_dims);
        break;
    case 8:
        gather = std::make_shared<ngraph::opset8::Gather>(gatherData, indices, axisConstant, batch_dims);
        break;
    default:
        throw std::runtime_error("Unknown opset version");
    }

    // Both the Gather input and the Gather output must have the expected precision.
    if (gatherData->get_output_element_type(0) != precisionAfterOperation) {
        THROW_IE_LPT_EXCEPTION(*gatherData) << "unexpected precision '" << precisionAfterOperation << "' after operation";
    }
    if (gather->get_output_element_type(0) != precisionAfterOperation) {
        THROW_IE_LPT_EXCEPTION(*gather) << "unexpected precision '" << precisionAfterOperation << "' after operation";
    }

    const std::shared_ptr<Node> output = makeDequantization(gather, dequantizationAfter);
    output->set_friendly_name("output");

    ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(output) };
    return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "GatherFunction");
}
} // namespace subgraph
} // namespace builder
} // namespace ngraph