From 9e83b081f418e30988390ab2ee6665cf9c4411b2 Mon Sep 17 00:00:00 2001 From: Mang Guo Date: Thu, 2 Feb 2023 10:13:52 -0500 Subject: [PATCH] Add gather lpt transformation (#14597) * Add gather lpt transformation * Add per-channel gather lpt dequantization support * Fix review comments * Add GPU test case * Fix clang-format error gpu case build error * Fix comments * Fix clang-format check fail * Update docs * Fix comments * Add Gather opset1 quantization support --- docs/IE_PLUGIN_DG/layout.xml | 1 + .../low_precision_transformations/lpt.md | 3 + .../pipeline/step3_main.md | 1 + .../step3_main/movement/gather.md | 3 + docs/doxygen-xfail.txt | 1 + .../include/low_precision/gather.hpp | 25 ++ .../src/gather.cpp | 194 ++++++++++++ .../src/low_precision.cpp | 2 + .../tests/gather_transformation.cpp | 275 ++++++++++++++++++ .../gather_transformation.cpp | 89 ++++++ .../gather_transformation.cpp | 72 +++++ .../gather_transformation.hpp | 43 +++ .../gather_transformation.cpp | 56 ++++ .../lpt_ngraph_functions/gather_function.hpp | 54 ++++ .../src/gather_function.cpp | 133 +++++++++ 15 files changed, 952 insertions(+) create mode 100644 docs/IE_PLUGIN_DG/plugin_transformation_pipeline/low_precision_transformations/transformations/step3_main/movement/gather.md create mode 100644 src/common/low_precision_transformations/include/low_precision/gather.hpp create mode 100644 src/common/low_precision_transformations/src/gather.cpp create mode 100644 src/common/low_precision_transformations/tests/gather_transformation.cpp create mode 100644 src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/gather_transformation.cpp create mode 100644 src/tests/functional/plugin/gpu/shared_tests_instances/low_precision_transformations/gather_transformation.cpp create mode 100644 src/tests/functional/plugin/shared/include/low_precision_transformations/gather_transformation.hpp create mode 100644 src/tests/functional/plugin/shared/src/low_precision_transformations/gather_transformation.cpp create mode 100644 src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/gather_function.hpp create mode 100644 src/tests/ngraph_helpers/lpt_ngraph_functions/src/gather_function.cpp diff --git a/docs/IE_PLUGIN_DG/layout.xml b/docs/IE_PLUGIN_DG/layout.xml index c1e1bfa233d..6ba8aeb8750 100644 --- a/docs/IE_PLUGIN_DG/layout.xml +++ b/docs/IE_PLUGIN_DG/layout.xml @@ -44,6 +44,7 @@ + diff --git a/docs/IE_PLUGIN_DG/plugin_transformation_pipeline/low_precision_transformations/lpt.md b/docs/IE_PLUGIN_DG/plugin_transformation_pipeline/low_precision_transformations/lpt.md index 59510c2395c..c695f20ce15 100644 --- a/docs/IE_PLUGIN_DG/plugin_transformation_pipeline/low_precision_transformations/lpt.md +++ b/docs/IE_PLUGIN_DG/plugin_transformation_pipeline/low_precision_transformations/lpt.md @@ -60,6 +60,8 @@ LPT transformations propagate dequantization operations through the following op * [Squeeze-1](@ref openvino_docs_ops_shape_Reshape_1) * [StridedSlice-1](@ref openvino_docs_ops_movement_StridedSlice_1) * [Transpose-1](@ref openvino_docs_ops_movement_Transpose_1) +* [Gather-7](@ref openvino_docs_ops_movement_Gather_7) +* [Gather-8](@ref openvino_docs_ops_movement_Gather_8) * [Unsqueeze-1](@ref openvino_docs_ops_shape_Unsqueeze_1) * [VariadicSplit-1](@ref openvino_docs_ops_movement_VariadicSplit_1) @@ -149,6 +151,7 @@ This step has the most transformations. These transformations can be separated i * [FakeQuantizeTransformation](@ref openvino_docs_OV_UG_lpt_FakeQuantizeTransformation) * [InterpolateTransformation](@ref openvino_docs_OV_UG_lpt_InterpolateTransformation) * [GroupConvolutionTransformation](@ref openvino_docs_OV_UG_lpt_GroupConvolutionTransformation) +* [GatherTransformation](@ref openvino_docs_OV_UG_lpt_GatherTransformation) * [MatMulTransformation](@ref openvino_docs_OV_UG_lpt_MatMulTransformation) * [MaxPoolTransformation](@ref openvino_docs_OV_UG_lpt_MaxPoolTransformation) * [MultiplyTransformation](@ref openvino_docs_OV_UG_lpt_MultiplyTransformation) diff --git a/docs/IE_PLUGIN_DG/plugin_transformation_pipeline/low_precision_transformations/pipeline/step3_main.md b/docs/IE_PLUGIN_DG/plugin_transformation_pipeline/low_precision_transformations/pipeline/step3_main.md index e91b5c8be7f..850a5e6c4dd 100644 --- a/docs/IE_PLUGIN_DG/plugin_transformation_pipeline/low_precision_transformations/pipeline/step3_main.md +++ b/docs/IE_PLUGIN_DG/plugin_transformation_pipeline/low_precision_transformations/pipeline/step3_main.md @@ -12,6 +12,7 @@ Main transformations are the majority of low precision transformations. Transfor * [FakeQuantizeTransformation](@ref openvino_docs_OV_UG_lpt_FakeQuantizeTransformation) * [InterpolateTransformation](@ref openvino_docs_OV_UG_lpt_InterpolateTransformation) * [GroupConvolutionTransformation](@ref openvino_docs_OV_UG_lpt_GroupConvolutionTransformation) +* [GatherTransformation](@ref openvino_docs_OV_UG_lpt_GatherTransformation) * [MatMulTransformation](@ref openvino_docs_OV_UG_lpt_MatMulTransformation) * [MaxPoolTransformation](@ref openvino_docs_OV_UG_lpt_MaxPoolTransformation) * [MultiplyTransformation](@ref openvino_docs_OV_UG_lpt_MultiplyTransformation) diff --git a/docs/IE_PLUGIN_DG/plugin_transformation_pipeline/low_precision_transformations/transformations/step3_main/movement/gather.md b/docs/IE_PLUGIN_DG/plugin_transformation_pipeline/low_precision_transformations/transformations/step3_main/movement/gather.md new file mode 100644 index 00000000000..7e4f5904254 --- /dev/null +++ b/docs/IE_PLUGIN_DG/plugin_transformation_pipeline/low_precision_transformations/transformations/step3_main/movement/gather.md @@ -0,0 +1,3 @@ +# GatherTransformation transformation {#openvino_docs_OV_UG_lpt_GatherTransformation} + +ngraph::pass::low_precision::GatherTransformation class represents the `Gather` operation transformation. diff --git a/docs/doxygen-xfail.txt b/docs/doxygen-xfail.txt index f13c90aa78f..f3f1abd334e 100644 --- a/docs/doxygen-xfail.txt +++ b/docs/doxygen-xfail.txt @@ -23,6 +23,7 @@ openvino_docs_OV_UG_lpt_fusemultiplytofakequantizetransformation.rst openvino_docs_OV_UG_lpt_fusesubtracttofakequantizetransformation.rst openvino_docs_OV_UG_lpt_groupconvolutiontransformation.rst openvino_docs_OV_UG_lpt_interpolatetransformation.rst +openvino_docs_OV_UG_lpt_gathertransformation.rst openvino_docs_OV_UG_lpt_linopsequencefusion.rst openvino_docs_OV_UG_lpt_mvntransformation.rst openvino_docs_OV_UG_lpt_markupavgpoolprecisionpreserved.rst diff --git a/src/common/low_precision_transformations/include/low_precision/gather.hpp b/src/common/low_precision_transformations/include/low_precision/gather.hpp new file mode 100644 index 00000000000..a5158dccbf2 --- /dev/null +++ b/src/common/low_precision_transformations/include/low_precision/gather.hpp @@ -0,0 +1,25 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include "low_precision/layer_transformation.hpp" + +namespace ngraph { +namespace pass { +namespace low_precision { + +class LP_TRANSFORMATIONS_API GatherTransformation : public LayerTransformation { +public: + OPENVINO_RTTI("GatherTransformation", "0"); + GatherTransformation(const Params& params = Params()); + bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override; + bool isPrecisionPreserved(std::shared_ptr layer) const override; + bool canBeTransformed(const TransformationContext& context, std::shared_ptr layer) const override; +}; + +} // namespace low_precision +} // namespace pass +} // namespace ngraph diff --git a/src/common/low_precision_transformations/src/gather.cpp b/src/common/low_precision_transformations/src/gather.cpp new file mode 100644 index 00000000000..daaad0b3b27 --- /dev/null +++ b/src/common/low_precision_transformations/src/gather.cpp @@ -0,0 +1,194 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "low_precision/gather.hpp" + +#include +#include +#include +#include +#include +#include +#include + +#include "low_precision/network_helper.hpp" +#include "low_precision/rt_info/precision_preserved_attribute.hpp" +#include "itt.hpp" + +namespace ngraph { +namespace pass { +namespace low_precision { + +std::shared_ptr gatherDeqConstant( + const std::shared_ptr &gather, + const std::shared_ptr &dequantizationConstant) { + auto constant = ov::as_type_ptr(dequantizationConstant); + auto constantShape = constant->get_shape(); + if (shape_size(constantShape) == 1ul) { + return NetworkHelper::toScalar(constant); + } + + const auto rank = gather->get_input_partial_shape(0).size(); + if (rank != constantShape.size()) { + // case when constShape without batch + while ((constantShape.size() > 1) && (constantShape.size() < rank)) { + constantShape.insert(constantShape.begin(), 1); + } + const auto newConstant = fold( + constant, + ngraph::opset1::Constant::create(ngraph::element::i32, { constantShape.size() }, constantShape)); + constant = ov::as_type_ptr(newConstant); + } + + const int64_t axis = ov::as_type_ptr(gather->get_input_node_shared_ptr(2))->cast_vector()[0]; + const size_t normalizedAxis = normalize_axis(gather->get_friendly_name(), axis, gather->get_input_partial_shape(0).rank()); + + // Dequantization channel matches with gather axis + if (constantShape[normalizedAxis] != 1ul) { + const auto gather1 = ov::as_type_ptr(gather); + if (gather1) { + const auto output = fold( + constant, + gather1->input_value(1), + gather1->input_value(2)); + constant = ov::as_type_ptr(NetworkHelper::toScalarIfPossible(output)); + } + + const auto gather7 = ov::as_type_ptr(gather); + if (gather7) { + const auto output = fold( + constant, + gather7->input_value(1), + gather7->input_value(2), + gather7->get_batch_dims()); + constant = ov::as_type_ptr(NetworkHelper::toScalarIfPossible(output)); + } + + const auto gather8 = ov::as_type_ptr(gather); + if (gather8) { + const auto output = fold( + constant, + gather8->input_value(1), + gather8->input_value(2), + gather8->get_batch_dims()); + constant = ov::as_type_ptr(NetworkHelper::toScalarIfPossible(output)); + } + } + return constant; +} + +GatherTransformation::GatherTransformation(const Params& params) : LayerTransformation(params) { + MATCHER_SCOPE(GatherTransformation); + auto gather = pattern::wrap_type({ pattern::wrap_type(), + pattern::any_input(), + pattern::any_input() }); + + ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) { + auto op = m.get_match_root(); + if (transformation_callback(op)) { + return false; + } + return transform(*context, m); + }; + + auto m = std::make_shared(gather, matcher_name); + this->register_matcher(m, callback); +} + +bool GatherTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) { + auto node = m.get_match_root(); + if (!canBeTransformed(context, m.get_match_root())) { + return false; + } + + const std::shared_ptr gather = NetworkHelper::separateInStandaloneBranch(m.get_match_root(), defaultPrecisions); + FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(gather, defaultPrecisions); + + if (dequantization.multiply != nullptr) { + const auto newConstant = gatherDeqConstant(gather, dequantization.multiplyConstant); + replace_node(dequantization.multiplyConstant, newConstant); + } + if (dequantization.subtract != nullptr) { + const auto newConstant = gatherDeqConstant(gather, dequantization.subtractConstant); + replace_node(dequantization.subtractConstant, newConstant); + } + + moveDequantizationAfter(context, gather, NetworkHelper::getDequantization(gather, defaultPrecisions), false); + return true; +} + +bool GatherTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr operation) const { + if (!LayerTransformation::canBeTransformed(context, operation)) { + return false; + } + + auto dequantization = NetworkHelper::getDequantization(operation, defaultPrecisions); + if (dequantization.empty()) { + return false; + } + + const auto isScalar = [&] { + if (dequantization.multiply != nullptr) { + if (!NetworkHelper::isScalarLike(dequantization.multiplyConstant)) { + return false; + } + } + if (dequantization.subtract != nullptr) { + if (!NetworkHelper::isScalarLike(dequantization.subtractConstant)) { + return false; + } + } + return true; + }(); + if (isScalar) { + return true; + } + + // If dequantization constant is not scalar, Gather axis must be constant. + // If the Gather axis matches with dequantization channel, the Gather indices + // must be constant and have 0D or 1D shape so we can do folding. + const auto axisConstant = ov::as_type_ptr(operation->get_input_node_shared_ptr(2)); + if (axisConstant == nullptr) { + return false; + } + + if (operation->get_input_partial_shape(0).rank().is_dynamic()) { + return false; + } + const auto canBeFolded = [&](const std::shared_ptr dequantizationConstant) { + auto constantShape = dequantizationConstant->get_shape(); + const auto rank = operation->get_input_partial_shape(0).size(); + if (rank != constantShape.size()) { + while ((constantShape.size() > 1) && (constantShape.size() < rank)) { + constantShape.insert(constantShape.begin(), 1); + } + } + const int64_t axis = axisConstant->cast_vector()[0]; + const size_t normalizedAxis = normalize_axis(operation->get_friendly_name(), axis, operation->get_input_partial_shape(0).rank()); + if (constantShape[normalizedAxis] != 1ul) { + const auto indicesConstant = ov::as_type_ptr(operation->get_input_node_shared_ptr(1)); + if (indicesConstant == nullptr) + return false; + const auto indicesShape = indicesConstant->get_shape(); + if (indicesShape.size() != 0 && indicesShape.size() != 1) { + return false; + } + } + return true; + }; + + if ((dequantization.multiply && !canBeFolded(dequantization.multiplyConstant)) || + (dequantization.subtract && !canBeFolded(dequantization.subtractConstant))) { + return false; + } + return true; +} + +bool GatherTransformation::isPrecisionPreserved(std::shared_ptr layer) const { + return true; +} + +} // namespace low_precision +} // namespace pass +} // namespace ngraph diff --git a/src/common/low_precision_transformations/src/low_precision.cpp b/src/common/low_precision_transformations/src/low_precision.cpp index dd642dc22cf..fae55a8ab93 100644 --- a/src/common/low_precision_transformations/src/low_precision.cpp +++ b/src/common/low_precision_transformations/src/low_precision.cpp @@ -69,6 +69,7 @@ #include "low_precision/shuffle_channels.hpp" #include "low_precision/strided_slice.hpp" #include "low_precision/transpose.hpp" +#include "low_precision/gather.hpp" #include "low_precision/unsqueeze.hpp" #include "low_precision/variadic_split.hpp" #include "low_precision/move_fake_quantize.hpp" @@ -246,6 +247,7 @@ bool ngraph::pass::low_precision::LowPrecision::run_on_model(const std::shared_p ADD_MATCHER(common, SplitTransformation, params) ADD_MATCHER(common, StridedSliceTransformation, params) ADD_MATCHER(common, TransposeTransformation, params) + ADD_MATCHER(common, GatherTransformation, params) ADD_MATCHER(common, UnsqueezeTransformation, params) ADD_MATCHER(common, VariadicSplitTransformation, params) diff --git a/src/common/low_precision_transformations/tests/gather_transformation.cpp b/src/common/low_precision_transformations/tests/gather_transformation.cpp new file mode 100644 index 00000000000..e9ca8d361af --- /dev/null +++ b/src/common/low_precision_transformations/tests/gather_transformation.cpp @@ -0,0 +1,275 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include +#include +#include +#include + +#include "common_test_utils/ngraph_test_utils.hpp" +#include "layer_transformation.hpp" +#include "lpt_ngraph_functions/common/dequantization_operations.hpp" +#include "lpt_ngraph_functions/gather_function.hpp" +#include "simple_low_precision_transformer.hpp" + +namespace { +using namespace testing; +using namespace ngraph::pass; +using namespace ngraph; + +class GatherTransformationTestValues { +public: + class Actual { + public: + ngraph::element::Type precisionBeforeDequantization; + ngraph::builder::subgraph::DequantizationOperations dequantization; + }; + + class Expected { + public: + ngraph::element::Type precisionBeforeDequantization; + ngraph::builder::subgraph::DequantizationOperations dequantizationBefore; + ngraph::element::Type precisionAfterOperation; + ngraph::builder::subgraph::DequantizationOperations dequantizationAfter; + }; + + std::vector gatherIndicesShape; + std::vector gatherIndicesValues; + std::vector axis; + int64_t batch_dims; + TestTransformationParams params; + Actual actual; + Expected expected; +}; + +typedef std::tuple GatherTransformationParams; + +class GatherTransformation : public LayerTransformation, + public testing::WithParamInterface { +public: + void SetUp() override { + const ngraph::PartialShape inputShape = std::get<0>(GetParam()); + const GatherTransformationTestValues testValues = std::get<1>(GetParam()); + const int opset_version = std::get<2>(GetParam()); + + actualFunction = + ngraph::builder::subgraph::GatherFunction::getOriginal(inputShape, + testValues.gatherIndicesShape, + testValues.gatherIndicesValues, + testValues.axis, + testValues.batch_dims, + testValues.actual.precisionBeforeDequantization, + testValues.actual.dequantization, + opset_version); + + SimpleLowPrecisionTransformer transformer; + transformer.add(testValues.params); + transformer.transform(actualFunction); + + referenceFunction = + ngraph::builder::subgraph::GatherFunction::getReference(inputShape, + testValues.gatherIndicesShape, + testValues.gatherIndicesValues, + testValues.axis, + testValues.batch_dims, + testValues.expected.precisionBeforeDequantization, + testValues.expected.dequantizationBefore, + testValues.expected.precisionAfterOperation, + testValues.expected.dequantizationAfter, + opset_version); + } + + static std::string getTestCaseName(testing::TestParamInfo obj) { + const ngraph::PartialShape inputShape = std::get<0>(obj.param); + const GatherTransformationTestValues testValues = std::get<1>(obj.param); + const int opset_version = std::get<2>(obj.param); + + std::ostringstream result; + result << "_" << inputShape << "_" << testValues.gatherIndicesShape << "_" << testValues.gatherIndicesValues + << "_" << testValues.axis << "_" << testValues.batch_dims << "_" + << testValues.actual.precisionBeforeDequantization << "_" << testValues.actual.dequantization << "_" + << testValues.expected.dequantizationBefore << "_" << opset_version; + return result.str(); + } +}; + +TEST_P(GatherTransformation, CompareFunctions) { + ov::pass::InitNodeInfo().run_on_model(actualFunction); + actualFunction->validate_nodes_and_infer_types(); + auto res = compare_functions(actualFunction, referenceFunction, true, true); + ASSERT_TRUE(res.first) << res.second; + + ASSERT_TRUE(LayerTransformation::allNamesAreUnique(actualFunction)) << "Not all names are unique"; +} + +namespace testValues1 { +const std::vector opset_version = {1, 7, 8}; + +const std::vector inputShapes3D = {{3, 3, 4}, {-1, -1, -1}}; + +const std::vector testValues = { + // U8: per-tensor quantization + {{1}, + {0}, + {0}, + std::int64_t{0}, + LayerTransformation::createParamsU8I8(), + {ngraph::element::u8, + {{ngraph::element::f32}, {{128}, ngraph::element::f32, {}, true, 1, ngraph::element::u8, true}, {0.1f}}}, + {ngraph::element::u8, + {{}, {}, {}}, + ngraph::element::u8, + {{ngraph::element::f32}, {{128}, ngraph::element::f32, {}, true, 1, ngraph::element::u8, true}, {0.1f}}}}, + // U8: per-tensor quantization + {{2}, + {0, 1}, + {0}, + std::int64_t{0}, + LayerTransformation::createParamsU8I8(), + {ngraph::element::u8, {{ngraph::element::f32}, {128}, {0.1f}}}, + {ngraph::element::u8, {{}, {}, {}}, ngraph::element::u8, {{ngraph::element::f32}, {128}, {0.1f}}}}, + // U8: per-tensor quantization + {{3, 2}, + {1, 2, 1, 2, 1, 2}, + {1}, + std::int64_t{1}, + LayerTransformation::createParamsU8I8(), + {ngraph::element::u8, {{ngraph::element::f32}, {128}, {0.1f}}}, + {ngraph::element::u8, {{}, {}, {}}, ngraph::element::u8, {{ngraph::element::f32}, {128}, {0.1f}}}}, + // U8: per-channel quantization with the same values + {{1}, + {0}, + {0}, + std::int64_t{0}, + LayerTransformation::createParamsU8I8(), + {ngraph::element::u8, + {{ngraph::element::f32}, + {{128.f}, element::undefined, {1, 3, 1}, false, 1ul, element::u8, true}, + {{0.1}, ngraph::element::f32, {1, 3, 1}}}}, + {ngraph::element::u8, + {{}, {}, {}}, + ngraph::element::u8, + {{ngraph::element::f32}, + {{128.f}, element::undefined, {1, 3, 1}, false, 1ul, element::u8, true}, + {{0.1}, ngraph::element::f32, {1, 3, 1}}}}}, + // U8: per-channel quantization, gather axis match with channel + {{1}, + {0}, + {1}, // axis + std::int64_t{0}, + LayerTransformation::createParamsU8I8(), + {ngraph::element::u8, + {{ngraph::element::f32}, + {{128, 64, 32}, ngraph::element::f32, {1, 3, 1}}, + {{0.3f, 0.2f, 0.1f}, ngraph::element::f32, {1, 3, 1}}}}, + {ngraph::element::u8, + {{}, {}, {}}, + ngraph::element::u8, + {{ngraph::element::f32}, {{128}, ngraph::element::f32, {}}, {{0.3f}, ngraph::element::f32, {}}}}}, + // U8: per-channel quantization, gather axis match with channel, quantization constant shape size is + // less than input shape + {{1}, + {1}, + {1}, // axis + std::int64_t{0}, + LayerTransformation::createParamsU8I8(), + {ngraph::element::u8, + {{ngraph::element::f32}, + {{128, 64, 32}, ngraph::element::f32, {3, 1}}, + {{0.3f, 0.2f, 0.1f}, ngraph::element::f32, {3, 1}}}}, + {ngraph::element::u8, + {{}, {}, {}}, + ngraph::element::u8, + {{ngraph::element::f32}, {{64}, ngraph::element::f32, {}}, {{0.2f}, ngraph::element::f32, {}}}}}, + // U8: per-channel quantization, gather axis and channel doesn't match + {{1}, + {0}, + {0}, + std::int64_t{0}, + LayerTransformation::createParamsU8I8(), + {ngraph::element::u8, + {{ngraph::element::f32}, + {{128, 64, 32}, ngraph::element::f32, {1, 3, 1}}, + {{0.3f, 0.2f, 0.1f}, ngraph::element::f32, {1, 3, 1}}}}, + {ngraph::element::u8, + {{}, {}, {}}, + ngraph::element::u8, + {{ngraph::element::f32}, + {{128, 64, 32}, ngraph::element::f32, {1, 3, 1}}, + {{0.3f, 0.2f, 0.1f}, ngraph::element::f32, {1, 3, 1}}}}}, + // U8: per-channel quantization, negative axis, gather axis match with channel + {{1}, + {0}, + {-2}, // axis + std::int64_t{0}, + LayerTransformation::createParamsU8I8(), + {ngraph::element::u8, + {{ngraph::element::f32}, + {{128, 64, 32}, ngraph::element::f32, {1, 3, 1}}, + {{0.3f, 0.2f, 0.1f}, ngraph::element::f32, {1, 3, 1}}}}, + {ngraph::element::u8, + {{}, {}, {}}, + ngraph::element::u8, + {{ngraph::element::f32}, {{128}, ngraph::element::f32, {}}, {{0.3f}, ngraph::element::f32, {}}}}}, + // empty + {{1}, + {0}, + {0}, + std::int64_t{0}, + LayerTransformation::createParamsU8I8(), + {ngraph::element::u8, {}}, + {ngraph::element::u8, {}, ngraph::element::u8, {}}}, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_LPT, + GatherTransformation, + ::testing::Combine(::testing::ValuesIn(inputShapes3D), + ::testing::ValuesIn(testValues), + ::testing::ValuesIn(opset_version)), + GatherTransformation::getTestCaseName); +} // namespace testValues1 + +namespace testValues2 { +const std::vector opset_version = {8}; + +const std::vector inputShapes3D = {{3, 3, 4}, {-1, -1, -1}}; + +const std::vector testValues = { + // U8: per-tensor quantization, negative indices value + {{3, 2}, + {-2, 2, -2, 2, -2, 2}, // indices value + {1}, + std::int64_t{1}, + LayerTransformation::createParamsU8I8(), + {ngraph::element::u8, {{ngraph::element::f32}, {128}, {0.1f}}}, + {ngraph::element::u8, {{}, {}, {}}, ngraph::element::u8, {{ngraph::element::f32}, {128}, {0.1f}}}}, + // U8: per-channel quantization, negative indices value, gather axis match with channel + {{1}, + {-1}, // indices value + {1}, // axis + std::int64_t{0}, + LayerTransformation::createParamsU8I8(), + {ngraph::element::u8, + {{ngraph::element::f32}, + {{128, 64, 32}, ngraph::element::f32, {1, 3, 1}}, + {{0.3f, 0.2f, 0.1f}, ngraph::element::f32, {1, 3, 1}}}}, + {ngraph::element::u8, + {{}, {}, {}}, + ngraph::element::u8, + {{ngraph::element::f32}, {{32}, ngraph::element::f32, {}}, {{0.1f}, ngraph::element::f32, {}}}}}, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_LPT, + GatherTransformation, + ::testing::Combine(::testing::ValuesIn(inputShapes3D), + ::testing::ValuesIn(testValues), + ::testing::ValuesIn(opset_version)), + GatherTransformation::getTestCaseName); +} // namespace testValues2 + +} // namespace diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/gather_transformation.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/gather_transformation.cpp new file mode 100644 index 00000000000..8e30950ee54 --- /dev/null +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/gather_transformation.cpp @@ -0,0 +1,89 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "low_precision_transformations/gather_transformation.hpp" +#include "common_test_utils/test_constants.hpp" + +using namespace LayerTestsDefinitions; + +namespace { +const std::vector precisions = { + ngraph::element::f32, +}; + +const std::vector opset_version = { + 1, 7, 8 +}; + +const std::vector testValues = { + // U8: per-tensor quantization + { + {3, 3, 4}, + {1}, + {0}, + {0}, + std::int64_t{0}, + LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(), + ngraph::element::f32, + {256, {}, {0.f}, {25.5f}, {12.5f}, {25.5f + 12.5f}} + }, + // U8: per-channel quantization + { + {1, 3, 5}, + {1}, + {0}, + {0}, + std::int64_t{0}, + LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(), + ngraph::element::f32, + { + 256, + {1, 3, 1}, + {0.f, 0.f, 0.f}, + {25.5f, 25.5f, 25.5f}, + {0.f, 12.5f, 25.5f}, + {25.5f, 25.5f + 12.5f * 2, 25.5f + 12.5f * 4} + } + }, + // U8: per-channel quantization, axis match with dequantization channel, dequantization constant shape is less than gather input shape + { + {1, 3, 4}, + {1}, + {0}, + {1}, + std::int64_t{0}, + LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(), + ngraph::element::f32, + { + 256, + {3, 1}, + {0.f, 0.f, 0.f}, + {25.5f, 25.5f, 25.5f}, + {0.f, 12.5f, 25.5f}, + {25.5f, 25.5f + 12.5f * 2, 25.5f + 12.5f * 4} + } + }, + // 4D + { + {3, 4, 100, 2}, + {2}, + {1, 2}, + {0}, + std::int64_t{0}, + LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(), + ngraph::element::f32, + {256, {}, {0.f}, {25.5f}, {12.5f}, {25.5f + 12.5f}} + }, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_LPT, GatherTransformation, + ::testing::Combine( + ::testing::ValuesIn(precisions), + ::testing::Values(CommonTestUtils::DEVICE_CPU), + ::testing::ValuesIn(testValues), + ::testing::ValuesIn(opset_version)), + GatherTransformation::getTestCaseName); +} // namespace diff --git a/src/tests/functional/plugin/gpu/shared_tests_instances/low_precision_transformations/gather_transformation.cpp b/src/tests/functional/plugin/gpu/shared_tests_instances/low_precision_transformations/gather_transformation.cpp new file mode 100644 index 00000000000..f1f2a590ce6 --- /dev/null +++ b/src/tests/functional/plugin/gpu/shared_tests_instances/low_precision_transformations/gather_transformation.cpp @@ -0,0 +1,72 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include "low_precision_transformations/gather_transformation.hpp" +#include "common_test_utils/test_constants.hpp" + +using namespace LayerTestsDefinitions; + +namespace { +const std::vector precisions = { + ngraph::element::f32, + ngraph::element::f16 +}; + +const std::vector opset_version = { + 1, 7, 8 +}; + +const std::vector testValues = { + // U8: per-tensor quantization + { + {3, 3, 4}, + {1}, + {0}, + {0}, + std::int64_t{0}, + LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(), + ngraph::element::f32, + {256, {}, {0.f}, {25.5f}, {12.5f}, {25.5f + 12.5f}} + }, + // U8: per-channel quantization + { + {1, 3, 5}, + {1}, + {0}, + {0}, + std::int64_t{0}, + LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(), + ngraph::element::f32, + { + 256, + {1, 3, 1}, + {0.f, 0.f, 0.f}, + {25.5f, 25.5f, 25.5f}, + {0.f, 12.5f, 25.5f}, + {25.5f, 25.5f + 12.5f * 2, 25.5f + 12.5f * 4} + } + }, + // 4D + { + {3, 4, 100, 2}, + {2}, + {1, 2}, + {0}, + std::int64_t{0}, + LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(), + ngraph::element::f32, + {256, {}, {0.f}, {25.5f}, {12.5f}, {25.5f + 12.5f}} + }, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_LPT, GatherTransformation, + ::testing::Combine( + ::testing::ValuesIn(precisions), + ::testing::Values(CommonTestUtils::DEVICE_GPU), + ::testing::ValuesIn(testValues), + ::testing::ValuesIn(opset_version)), + GatherTransformation::getTestCaseName); +} // namespace diff --git a/src/tests/functional/plugin/shared/include/low_precision_transformations/gather_transformation.hpp b/src/tests/functional/plugin/shared/include/low_precision_transformations/gather_transformation.hpp new file mode 100644 index 00000000000..b700a72226c --- /dev/null +++ b/src/tests/functional/plugin/shared/include/low_precision_transformations/gather_transformation.hpp @@ -0,0 +1,43 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include + +#include "shared_test_classes/base/low_precision_transformations/layer_transformation.hpp" +#include "lpt_ngraph_functions/common/dequantization_operations.hpp" + +namespace LayerTestsDefinitions { + +class GatherTransformationTestValues { +public: + ngraph::PartialShape inputShape; + std::vector gatherIndicesShape; + std::vector gatherIndicesValues; + std::vector axis; + int64_t batch_dims; + ngraph::pass::low_precision::LayerTransformation::Params params; + ngraph::element::Type precisionBeforeFq; + ngraph::builder::subgraph::FakeQuantizeOnData fqOnData; +}; + +typedef std::tuple< + ngraph::element::Type, + std::string, + GatherTransformationTestValues, + int> GatherTransformationParams; + +class GatherTransformation : + public testing::WithParamInterface, + public LayerTestsUtils::LayerTransformation { +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj); + +protected: + void SetUp() override; +}; + +} // namespace LayerTestsDefinitions diff --git a/src/tests/functional/plugin/shared/src/low_precision_transformations/gather_transformation.cpp b/src/tests/functional/plugin/shared/src/low_precision_transformations/gather_transformation.cpp new file mode 100644 index 00000000000..682ce15fad7 --- /dev/null +++ b/src/tests/functional/plugin/shared/src/low_precision_transformations/gather_transformation.cpp @@ -0,0 +1,56 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "low_precision_transformations/gather_transformation.hpp" + +#include +#include +#include +#include +#include + +#include +#include "lpt_ngraph_functions/gather_function.hpp" + +namespace LayerTestsDefinitions { + +std::string GatherTransformation::getTestCaseName(const testing::TestParamInfo& obj) { + ngraph::element::Type precision; + std::string targetDevice; + GatherTransformationTestValues testValues; + int opset_version; + std::tie(precision, targetDevice, testValues, opset_version) = obj.param; + + std::ostringstream result; + result << + precision << "_" << + targetDevice << "_" << + testValues.inputShape << "_" << + opset_version; + + return result.str(); +} + +void GatherTransformation::SetUp() { + ngraph::element::Type precision; + GatherTransformationTestValues testValues; + int opset_version; + std::tie(precision, targetDevice, testValues, opset_version) = this->GetParam(); + + function = ngraph::builder::subgraph::GatherFunction::getOriginal( + testValues.inputShape, + testValues.gatherIndicesShape, + testValues.gatherIndicesValues, + testValues.axis, + testValues.batch_dims, + testValues.precisionBeforeFq, + testValues.fqOnData, + opset_version); +} + +TEST_P(GatherTransformation, CompareWithRefImpl) { + Run(); +}; + +} // namespace LayerTestsDefinitions diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/gather_function.hpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/gather_function.hpp new file mode 100644 index 00000000000..fe233dced8f --- /dev/null +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/include/lpt_ngraph_functions/gather_function.hpp @@ -0,0 +1,54 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include +#include +#include +#include "lpt_ngraph_functions/common/dequantization_operations.hpp" +#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp" + +namespace ngraph { +namespace builder { +namespace subgraph { + +class GatherFunction { +public: + static std::shared_ptr getOriginal( + const ngraph::PartialShape& inputShape, + const std::vector& gatherIndicesShape, + const std::vector& gatherIndicesValues, + const std::vector& axis, + const int64_t batch_dims, + const ngraph::element::Type precisionBeforeDequantization, + const ngraph::builder::subgraph::DequantizationOperations& dequantization, + const int opset_version); + + static std::shared_ptr getOriginal( + const ngraph::PartialShape& inputShape, + const std::vector& gatherIndicesShape, + const std::vector& gatherIndicesValues, + const std::vector& axis, + const int64_t batch_dims, + const ngraph::element::Type precisionBeforeFq, + const FakeQuantizeOnData& fqOnData, + const int opset_version); + + static std::shared_ptr getReference( + const ngraph::PartialShape& inputShape, + const std::vector& gatherIndicesShape, + const std::vector& gatherIndicesValues, + const std::vector& axis, + const int64_t batch_dims, + const ngraph::element::Type precisionBeforeDequantization, + const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore, + const ngraph::element::Type precisionAfterOperation, + const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter, + const int opset_version); +}; + +} // namespace subgraph +} // namespace builder +} // namespace ngraph diff --git a/src/tests/ngraph_helpers/lpt_ngraph_functions/src/gather_function.cpp b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/gather_function.cpp new file mode 100644 index 00000000000..45ee1436994 --- /dev/null +++ b/src/tests/ngraph_helpers/lpt_ngraph_functions/src/gather_function.cpp @@ -0,0 +1,133 @@ +// Copyright (C) 2018-2022 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "lpt_ngraph_functions/gather_function.hpp" + +#include +#include +#include +#include "lpt_ngraph_functions/common/builders.hpp" + +namespace ngraph { +namespace builder { +namespace subgraph { + +std::shared_ptr GatherFunction::getOriginal( + const ngraph::PartialShape& inputShape, + const std::vector& gatherIndicesShape, + const std::vector& gatherIndicesValues, + const std::vector& axis, + const int64_t batch_dims, + const ngraph::element::Type precisionBeforeDequantization, + const ngraph::builder::subgraph::DequantizationOperations& dequantization, + const int opset_version) { + const auto input = std::make_shared(precisionBeforeDequantization, inputShape); + const std::shared_ptr dequantizationOp = makeDequantization(input, dequantization); + const auto indicesNode = std::make_shared( + ngraph::element::i64, + ngraph::Shape(gatherIndicesShape), + gatherIndicesValues); + const auto axisNode = std::make_shared(ngraph::element::i64, ngraph::Shape{ axis.size() }, axis); + std::shared_ptr gather; + if (opset_version == 7) { + gather = std::make_shared(dequantizationOp, indicesNode, axisNode, batch_dims); + } else if (opset_version == 8) { + gather = std::make_shared(dequantizationOp, indicesNode, axisNode, batch_dims); + } else if (opset_version == 1) { + gather = std::make_shared(dequantizationOp, indicesNode, axisNode); + } else { + throw std::runtime_error("Unknown opset version"); + } + gather->set_friendly_name("output"); + + ngraph::ResultVector results{ std::make_shared(gather) }; + return std::make_shared(results, ngraph::ParameterVector{ input }, "GatherFunction"); +} + +std::shared_ptr GatherFunction::getOriginal( + const ngraph::PartialShape& inputShape, + const std::vector& gatherIndicesShape, + const std::vector& gatherIndicesValues, + const std::vector& axis, + const int64_t batch_dims, + const ngraph::element::Type precisionBeforeFq, + const FakeQuantizeOnData& fqOnData, + const int opset_version) { + + const auto input = std::make_shared(precisionBeforeFq, inputShape); + + const std::shared_ptr quantizationOp = fqOnData.empty() ? + std::dynamic_pointer_cast(input) : + makeFakeQuantize(input, precisionBeforeFq, fqOnData); + + const auto indicesNode = std::make_shared( + ngraph::element::i64, + ngraph::Shape(gatherIndicesShape), + gatherIndicesValues); + const auto axisNode = std::make_shared(ngraph::element::i64, ngraph::Shape{ axis.size() }, axis); + + std::shared_ptr gather; + if (opset_version == 7) { + gather = std::make_shared(quantizationOp, indicesNode, axisNode, batch_dims); + } else if (opset_version == 8) { + gather = std::make_shared(quantizationOp, indicesNode, axisNode, batch_dims); + } else if (opset_version == 1) { + gather = std::make_shared(quantizationOp, indicesNode, axisNode); + } else { + throw std::runtime_error("Unknown opset version"); + } + + ngraph::ResultVector results{ std::make_shared(gather) }; + return std::make_shared(results, ngraph::ParameterVector{ input }, "GatherFunction"); +} + +std::shared_ptr GatherFunction::getReference( + const ngraph::PartialShape& inputShape, + const std::vector& gatherIndicesShape, + const std::vector& gatherIndicesValues, + const std::vector& axis, + const int64_t batch_dims, + const ngraph::element::Type precisionBeforeDequantization, + const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore, + const ngraph::element::Type precisionAfterOperation, + const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter, + const int opset_version) { + const auto input = std::make_shared(precisionBeforeDequantization, inputShape); + + const std::shared_ptr quantizationOpBefore = makeDequantization(input, dequantizationBefore); + + const auto indicesNode = std::make_shared( + ngraph::element::i64, + ngraph::Shape(gatherIndicesShape), + gatherIndicesValues); + const auto axisNode = std::make_shared(ngraph::element::i64, ngraph::Shape{ axis.size() }, axis); + + std::shared_ptr gather; + if (opset_version == 7) { + gather = std::make_shared(quantizationOpBefore, indicesNode, axisNode, batch_dims); + } else if (opset_version == 8) { + gather = std::make_shared(quantizationOpBefore, indicesNode, axisNode, batch_dims); + } else if (opset_version == 1) { + gather = std::make_shared(quantizationOpBefore, indicesNode, axisNode); + } else { + throw std::runtime_error("Unknown opset version"); + } + + if (quantizationOpBefore->get_output_element_type(0) != precisionAfterOperation) { + THROW_IE_LPT_EXCEPTION(*quantizationOpBefore) << "unexpected precision '" << precisionAfterOperation << "' after operation"; + } + if (gather->get_output_element_type(0) != precisionAfterOperation) { + THROW_IE_LPT_EXCEPTION(*gather) << "unexpected precision '" << precisionAfterOperation << "' after operation"; + } + + const std::shared_ptr quantizationOpAfter = makeDequantization(gather, dequantizationAfter); + quantizationOpAfter->set_friendly_name("output"); + + ngraph::ResultVector results{ std::make_shared(quantizationOpAfter) }; + return std::make_shared(results, ngraph::ParameterVector{ input }, "GatherFunction"); +} + +} // namespace subgraph +} // namespace builder +} // namespace ngraph