Add gather lpt transformation (#14597)
* Add gather lpt transformation * Add per-channel gather lpt dequantization support * Fix review comments * Add GPU test case * Fix clang-format error gpu case build error * Fix comments * Fix clang-format check fail * Update docs * Fix comments * Add Gather opset1 quantization support
This commit is contained in:
parent
d86ba0742c
commit
9e83b081f4
@ -44,6 +44,7 @@
|
||||
<tab type="user" title="FakeQuantizeDecompositionTransformation" url="@ref openvino_docs_OV_UG_lpt_FakeQuantizeDecompositionTransformation"/>
|
||||
<tab type="user" title="FakeQuantizeTransformation" url="@ref openvino_docs_OV_UG_lpt_FakeQuantizeTransformation"/>
|
||||
<tab type="user" title="InterpolateTransformation" url="@ref openvino_docs_OV_UG_lpt_InterpolateTransformation"/>
|
||||
<tab type="user" title="GatherTransformation" url="@ref openvino_docs_OV_UG_lpt_GatherTransformation"/>
|
||||
<tab type="user" title="GroupConvolutionTransformation" url="@ref openvino_docs_OV_UG_lpt_GroupConvolutionTransformation"/>
|
||||
<tab type="user" title="MatMulTransformation" url="@ref openvino_docs_OV_UG_lpt_MatMulTransformation"/>
|
||||
<tab type="user" title="MaxPoolTransformation" url="@ref openvino_docs_OV_UG_lpt_MaxPoolTransformation"/>
|
||||
|
@ -60,6 +60,8 @@ LPT transformations propagate dequantization operations through the following op
|
||||
* [Squeeze-1](@ref openvino_docs_ops_shape_Reshape_1)
|
||||
* [StridedSlice-1](@ref openvino_docs_ops_movement_StridedSlice_1)
|
||||
* [Transpose-1](@ref openvino_docs_ops_movement_Transpose_1)
|
||||
* [Gather-7](@ref openvino_docs_ops_movement_Gather_7)
|
||||
* [Gather-8](@ref openvino_docs_ops_movement_Gather_8)
|
||||
* [Unsqueeze-1](@ref openvino_docs_ops_shape_Unsqueeze_1)
|
||||
* [VariadicSplit-1](@ref openvino_docs_ops_movement_VariadicSplit_1)
|
||||
|
||||
@ -149,6 +151,7 @@ This step has the most transformations. These transformations can be separated i
|
||||
* [FakeQuantizeTransformation](@ref openvino_docs_OV_UG_lpt_FakeQuantizeTransformation)
|
||||
* [InterpolateTransformation](@ref openvino_docs_OV_UG_lpt_InterpolateTransformation)
|
||||
* [GroupConvolutionTransformation](@ref openvino_docs_OV_UG_lpt_GroupConvolutionTransformation)
|
||||
* [GatherTransformation](@ref openvino_docs_OV_UG_lpt_GatherTransformation)
|
||||
* [MatMulTransformation](@ref openvino_docs_OV_UG_lpt_MatMulTransformation)
|
||||
* [MaxPoolTransformation](@ref openvino_docs_OV_UG_lpt_MaxPoolTransformation)
|
||||
* [MultiplyTransformation](@ref openvino_docs_OV_UG_lpt_MultiplyTransformation)
|
||||
|
@ -12,6 +12,7 @@ Main transformations are the majority of low precision transformations. Transfor
|
||||
* [FakeQuantizeTransformation](@ref openvino_docs_OV_UG_lpt_FakeQuantizeTransformation)
|
||||
* [InterpolateTransformation](@ref openvino_docs_OV_UG_lpt_InterpolateTransformation)
|
||||
* [GroupConvolutionTransformation](@ref openvino_docs_OV_UG_lpt_GroupConvolutionTransformation)
|
||||
* [GatherTransformation](@ref openvino_docs_OV_UG_lpt_GatherTransformation)
|
||||
* [MatMulTransformation](@ref openvino_docs_OV_UG_lpt_MatMulTransformation)
|
||||
* [MaxPoolTransformation](@ref openvino_docs_OV_UG_lpt_MaxPoolTransformation)
|
||||
* [MultiplyTransformation](@ref openvino_docs_OV_UG_lpt_MultiplyTransformation)
|
||||
|
@ -0,0 +1,3 @@
|
||||
# GatherTransformation transformation {#openvino_docs_OV_UG_lpt_GatherTransformation}
|
||||
|
||||
ngraph::pass::low_precision::GatherTransformation class represents the `Gather` operation transformation.
|
@ -23,6 +23,7 @@ openvino_docs_OV_UG_lpt_fusemultiplytofakequantizetransformation.rst
|
||||
openvino_docs_OV_UG_lpt_fusesubtracttofakequantizetransformation.rst
|
||||
openvino_docs_OV_UG_lpt_groupconvolutiontransformation.rst
|
||||
openvino_docs_OV_UG_lpt_interpolatetransformation.rst
|
||||
openvino_docs_OV_UG_lpt_gathertransformation.rst
|
||||
openvino_docs_OV_UG_lpt_linopsequencefusion.rst
|
||||
openvino_docs_OV_UG_lpt_mvntransformation.rst
|
||||
openvino_docs_OV_UG_lpt_markupavgpoolprecisionpreserved.rst
|
||||
|
@ -0,0 +1,25 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <algorithm>
|
||||
#include "low_precision/layer_transformation.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
namespace low_precision {
|
||||
|
||||
class LP_TRANSFORMATIONS_API GatherTransformation : public LayerTransformation {
|
||||
public:
|
||||
OPENVINO_RTTI("GatherTransformation", "0");
|
||||
GatherTransformation(const Params& params = Params());
|
||||
bool transform(TransformationContext& context, ngraph::pattern::Matcher &m) override;
|
||||
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const override;
|
||||
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> layer) const override;
|
||||
};
|
||||
|
||||
} // namespace low_precision
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
194
src/common/low_precision_transformations/src/gather.cpp
Normal file
194
src/common/low_precision_transformations/src/gather.cpp
Normal file
@ -0,0 +1,194 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "low_precision/gather.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include <ngraph/pattern/op/or.hpp>
|
||||
#include <ngraph/pattern/op/wrap_type.hpp>
|
||||
|
||||
#include "low_precision/network_helper.hpp"
|
||||
#include "low_precision/rt_info/precision_preserved_attribute.hpp"
|
||||
#include "itt.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
namespace low_precision {
|
||||
|
||||
std::shared_ptr<opset1::Constant> gatherDeqConstant(
|
||||
const std::shared_ptr<ngraph::Node> &gather,
|
||||
const std::shared_ptr<ngraph::Node> &dequantizationConstant) {
|
||||
auto constant = ov::as_type_ptr<ngraph::opset1::Constant>(dequantizationConstant);
|
||||
auto constantShape = constant->get_shape();
|
||||
if (shape_size(constantShape) == 1ul) {
|
||||
return NetworkHelper::toScalar(constant);
|
||||
}
|
||||
|
||||
const auto rank = gather->get_input_partial_shape(0).size();
|
||||
if (rank != constantShape.size()) {
|
||||
// case when constShape without batch
|
||||
while ((constantShape.size() > 1) && (constantShape.size() < rank)) {
|
||||
constantShape.insert(constantShape.begin(), 1);
|
||||
}
|
||||
const auto newConstant = fold<ngraph::opset1::Broadcast>(
|
||||
constant,
|
||||
ngraph::opset1::Constant::create(ngraph::element::i32, { constantShape.size() }, constantShape));
|
||||
constant = ov::as_type_ptr<ngraph::opset1::Constant>(newConstant);
|
||||
}
|
||||
|
||||
const int64_t axis = ov::as_type_ptr<opset1::Constant>(gather->get_input_node_shared_ptr(2))->cast_vector<int64_t>()[0];
|
||||
const size_t normalizedAxis = normalize_axis(gather->get_friendly_name(), axis, gather->get_input_partial_shape(0).rank());
|
||||
|
||||
// Dequantization channel matches with gather axis
|
||||
if (constantShape[normalizedAxis] != 1ul) {
|
||||
const auto gather1 = ov::as_type_ptr<ngraph::opset1::Gather>(gather);
|
||||
if (gather1) {
|
||||
const auto output = fold<ngraph::opset1::Gather>(
|
||||
constant,
|
||||
gather1->input_value(1),
|
||||
gather1->input_value(2));
|
||||
constant = ov::as_type_ptr<opset1::Constant>(NetworkHelper::toScalarIfPossible(output));
|
||||
}
|
||||
|
||||
const auto gather7 = ov::as_type_ptr<ngraph::opset7::Gather>(gather);
|
||||
if (gather7) {
|
||||
const auto output = fold<ngraph::opset7::Gather>(
|
||||
constant,
|
||||
gather7->input_value(1),
|
||||
gather7->input_value(2),
|
||||
gather7->get_batch_dims());
|
||||
constant = ov::as_type_ptr<opset1::Constant>(NetworkHelper::toScalarIfPossible(output));
|
||||
}
|
||||
|
||||
const auto gather8 = ov::as_type_ptr<ngraph::opset8::Gather>(gather);
|
||||
if (gather8) {
|
||||
const auto output = fold<ngraph::opset8::Gather>(
|
||||
constant,
|
||||
gather8->input_value(1),
|
||||
gather8->input_value(2),
|
||||
gather8->get_batch_dims());
|
||||
constant = ov::as_type_ptr<opset1::Constant>(NetworkHelper::toScalarIfPossible(output));
|
||||
}
|
||||
}
|
||||
return constant;
|
||||
}
|
||||
|
||||
GatherTransformation::GatherTransformation(const Params& params) : LayerTransformation(params) {
|
||||
MATCHER_SCOPE(GatherTransformation);
|
||||
auto gather = pattern::wrap_type<opset1::Gather, opset7::Gather, opset8::Gather>({ pattern::wrap_type<opset1::Multiply>(),
|
||||
pattern::any_input(),
|
||||
pattern::any_input() });
|
||||
|
||||
ngraph::graph_rewrite_callback callback = [this](pattern::Matcher& m) {
|
||||
auto op = m.get_match_root();
|
||||
if (transformation_callback(op)) {
|
||||
return false;
|
||||
}
|
||||
return transform(*context, m);
|
||||
};
|
||||
|
||||
auto m = std::make_shared<ngraph::pattern::Matcher>(gather, matcher_name);
|
||||
this->register_matcher(m, callback);
|
||||
}
|
||||
|
||||
bool GatherTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher &m) {
|
||||
auto node = m.get_match_root();
|
||||
if (!canBeTransformed(context, m.get_match_root())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const std::shared_ptr<Node> gather = NetworkHelper::separateInStandaloneBranch(m.get_match_root(), defaultPrecisions);
|
||||
FakeQuantizeDequantization dequantization = NetworkHelper::getDequantization(gather, defaultPrecisions);
|
||||
|
||||
if (dequantization.multiply != nullptr) {
|
||||
const auto newConstant = gatherDeqConstant(gather, dequantization.multiplyConstant);
|
||||
replace_node(dequantization.multiplyConstant, newConstant);
|
||||
}
|
||||
if (dequantization.subtract != nullptr) {
|
||||
const auto newConstant = gatherDeqConstant(gather, dequantization.subtractConstant);
|
||||
replace_node(dequantization.subtractConstant, newConstant);
|
||||
}
|
||||
|
||||
moveDequantizationAfter(context, gather, NetworkHelper::getDequantization(gather, defaultPrecisions), false);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GatherTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
|
||||
if (!LayerTransformation::canBeTransformed(context, operation)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
auto dequantization = NetworkHelper::getDequantization(operation, defaultPrecisions);
|
||||
if (dequantization.empty()) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto isScalar = [&] {
|
||||
if (dequantization.multiply != nullptr) {
|
||||
if (!NetworkHelper::isScalarLike(dequantization.multiplyConstant)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if (dequantization.subtract != nullptr) {
|
||||
if (!NetworkHelper::isScalarLike(dequantization.subtractConstant)) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
}();
|
||||
if (isScalar) {
|
||||
return true;
|
||||
}
|
||||
|
||||
// If dequantization constant is not scalar, Gather axis must be constant.
|
||||
// If the Gather axis matches with dequantization channel, the Gather indices
|
||||
// must be constant and have 0D or 1D shape so we can do folding.
|
||||
const auto axisConstant = ov::as_type_ptr<opset1::Constant>(operation->get_input_node_shared_ptr(2));
|
||||
if (axisConstant == nullptr) {
|
||||
return false;
|
||||
}
|
||||
|
||||
if (operation->get_input_partial_shape(0).rank().is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
const auto canBeFolded = [&](const std::shared_ptr<ngraph::Node> dequantizationConstant) {
|
||||
auto constantShape = dequantizationConstant->get_shape();
|
||||
const auto rank = operation->get_input_partial_shape(0).size();
|
||||
if (rank != constantShape.size()) {
|
||||
while ((constantShape.size() > 1) && (constantShape.size() < rank)) {
|
||||
constantShape.insert(constantShape.begin(), 1);
|
||||
}
|
||||
}
|
||||
const int64_t axis = axisConstant->cast_vector<int64_t>()[0];
|
||||
const size_t normalizedAxis = normalize_axis(operation->get_friendly_name(), axis, operation->get_input_partial_shape(0).rank());
|
||||
if (constantShape[normalizedAxis] != 1ul) {
|
||||
const auto indicesConstant = ov::as_type_ptr<opset1::Constant>(operation->get_input_node_shared_ptr(1));
|
||||
if (indicesConstant == nullptr)
|
||||
return false;
|
||||
const auto indicesShape = indicesConstant->get_shape();
|
||||
if (indicesShape.size() != 0 && indicesShape.size() != 1) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
return true;
|
||||
};
|
||||
|
||||
if ((dequantization.multiply && !canBeFolded(dequantization.multiplyConstant)) ||
|
||||
(dequantization.subtract && !canBeFolded(dequantization.subtractConstant))) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
|
||||
bool GatherTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const {
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace low_precision
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
@ -69,6 +69,7 @@
|
||||
#include "low_precision/shuffle_channels.hpp"
|
||||
#include "low_precision/strided_slice.hpp"
|
||||
#include "low_precision/transpose.hpp"
|
||||
#include "low_precision/gather.hpp"
|
||||
#include "low_precision/unsqueeze.hpp"
|
||||
#include "low_precision/variadic_split.hpp"
|
||||
#include "low_precision/move_fake_quantize.hpp"
|
||||
@ -246,6 +247,7 @@ bool ngraph::pass::low_precision::LowPrecision::run_on_model(const std::shared_p
|
||||
ADD_MATCHER(common, SplitTransformation, params)
|
||||
ADD_MATCHER(common, StridedSliceTransformation, params)
|
||||
ADD_MATCHER(common, TransposeTransformation, params)
|
||||
ADD_MATCHER(common, GatherTransformation, params)
|
||||
ADD_MATCHER(common, UnsqueezeTransformation, params)
|
||||
ADD_MATCHER(common, VariadicSplitTransformation, params)
|
||||
|
||||
|
@ -0,0 +1,275 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <low_precision/gather.hpp>
|
||||
#include <memory>
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "layer_transformation.hpp"
|
||||
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
|
||||
#include "lpt_ngraph_functions/gather_function.hpp"
|
||||
#include "simple_low_precision_transformer.hpp"
|
||||
|
||||
namespace {
|
||||
using namespace testing;
|
||||
using namespace ngraph::pass;
|
||||
using namespace ngraph;
|
||||
|
||||
class GatherTransformationTestValues {
|
||||
public:
|
||||
class Actual {
|
||||
public:
|
||||
ngraph::element::Type precisionBeforeDequantization;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization;
|
||||
};
|
||||
|
||||
class Expected {
|
||||
public:
|
||||
ngraph::element::Type precisionBeforeDequantization;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
|
||||
ngraph::element::Type precisionAfterOperation;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
|
||||
};
|
||||
|
||||
std::vector<size_t> gatherIndicesShape;
|
||||
std::vector<int> gatherIndicesValues;
|
||||
std::vector<int> axis;
|
||||
int64_t batch_dims;
|
||||
TestTransformationParams params;
|
||||
Actual actual;
|
||||
Expected expected;
|
||||
};
|
||||
|
||||
typedef std::tuple<ngraph::PartialShape, GatherTransformationTestValues, int> GatherTransformationParams;
|
||||
|
||||
class GatherTransformation : public LayerTransformation,
|
||||
public testing::WithParamInterface<GatherTransformationParams> {
|
||||
public:
|
||||
void SetUp() override {
|
||||
const ngraph::PartialShape inputShape = std::get<0>(GetParam());
|
||||
const GatherTransformationTestValues testValues = std::get<1>(GetParam());
|
||||
const int opset_version = std::get<2>(GetParam());
|
||||
|
||||
actualFunction =
|
||||
ngraph::builder::subgraph::GatherFunction::getOriginal(inputShape,
|
||||
testValues.gatherIndicesShape,
|
||||
testValues.gatherIndicesValues,
|
||||
testValues.axis,
|
||||
testValues.batch_dims,
|
||||
testValues.actual.precisionBeforeDequantization,
|
||||
testValues.actual.dequantization,
|
||||
opset_version);
|
||||
|
||||
SimpleLowPrecisionTransformer transformer;
|
||||
transformer.add<ngraph::pass::low_precision::GatherTransformation, ngraph::opset1::Gather>(testValues.params);
|
||||
transformer.transform(actualFunction);
|
||||
|
||||
referenceFunction =
|
||||
ngraph::builder::subgraph::GatherFunction::getReference(inputShape,
|
||||
testValues.gatherIndicesShape,
|
||||
testValues.gatherIndicesValues,
|
||||
testValues.axis,
|
||||
testValues.batch_dims,
|
||||
testValues.expected.precisionBeforeDequantization,
|
||||
testValues.expected.dequantizationBefore,
|
||||
testValues.expected.precisionAfterOperation,
|
||||
testValues.expected.dequantizationAfter,
|
||||
opset_version);
|
||||
}
|
||||
|
||||
static std::string getTestCaseName(testing::TestParamInfo<GatherTransformationParams> obj) {
|
||||
const ngraph::PartialShape inputShape = std::get<0>(obj.param);
|
||||
const GatherTransformationTestValues testValues = std::get<1>(obj.param);
|
||||
const int opset_version = std::get<2>(obj.param);
|
||||
|
||||
std::ostringstream result;
|
||||
result << "_" << inputShape << "_" << testValues.gatherIndicesShape << "_" << testValues.gatherIndicesValues
|
||||
<< "_" << testValues.axis << "_" << testValues.batch_dims << "_"
|
||||
<< testValues.actual.precisionBeforeDequantization << "_" << testValues.actual.dequantization << "_"
|
||||
<< testValues.expected.dequantizationBefore << "_" << opset_version;
|
||||
return result.str();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(GatherTransformation, CompareFunctions) {
|
||||
ov::pass::InitNodeInfo().run_on_model(actualFunction);
|
||||
actualFunction->validate_nodes_and_infer_types();
|
||||
auto res = compare_functions(actualFunction, referenceFunction, true, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
|
||||
ASSERT_TRUE(LayerTransformation::allNamesAreUnique(actualFunction)) << "Not all names are unique";
|
||||
}
|
||||
|
||||
namespace testValues1 {
|
||||
const std::vector<int> opset_version = {1, 7, 8};
|
||||
|
||||
const std::vector<ngraph::PartialShape> inputShapes3D = {{3, 3, 4}, {-1, -1, -1}};
|
||||
|
||||
const std::vector<GatherTransformationTestValues> testValues = {
|
||||
// U8: per-tensor quantization
|
||||
{{1},
|
||||
{0},
|
||||
{0},
|
||||
std::int64_t{0},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {{128}, ngraph::element::f32, {}, true, 1, ngraph::element::u8, true}, {0.1f}}},
|
||||
{ngraph::element::u8,
|
||||
{{}, {}, {}},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {{128}, ngraph::element::f32, {}, true, 1, ngraph::element::u8, true}, {0.1f}}}},
|
||||
// U8: per-tensor quantization
|
||||
{{2},
|
||||
{0, 1},
|
||||
{0},
|
||||
std::int64_t{0},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{ngraph::element::u8, {{ngraph::element::f32}, {128}, {0.1f}}},
|
||||
{ngraph::element::u8, {{}, {}, {}}, ngraph::element::u8, {{ngraph::element::f32}, {128}, {0.1f}}}},
|
||||
// U8: per-tensor quantization
|
||||
{{3, 2},
|
||||
{1, 2, 1, 2, 1, 2},
|
||||
{1},
|
||||
std::int64_t{1},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{ngraph::element::u8, {{ngraph::element::f32}, {128}, {0.1f}}},
|
||||
{ngraph::element::u8, {{}, {}, {}}, ngraph::element::u8, {{ngraph::element::f32}, {128}, {0.1f}}}},
|
||||
// U8: per-channel quantization with the same values
|
||||
{{1},
|
||||
{0},
|
||||
{0},
|
||||
std::int64_t{0},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{ngraph::element::u8,
|
||||
{{ngraph::element::f32},
|
||||
{{128.f}, element::undefined, {1, 3, 1}, false, 1ul, element::u8, true},
|
||||
{{0.1}, ngraph::element::f32, {1, 3, 1}}}},
|
||||
{ngraph::element::u8,
|
||||
{{}, {}, {}},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32},
|
||||
{{128.f}, element::undefined, {1, 3, 1}, false, 1ul, element::u8, true},
|
||||
{{0.1}, ngraph::element::f32, {1, 3, 1}}}}},
|
||||
// U8: per-channel quantization, gather axis match with channel
|
||||
{{1},
|
||||
{0},
|
||||
{1}, // axis
|
||||
std::int64_t{0},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{ngraph::element::u8,
|
||||
{{ngraph::element::f32},
|
||||
{{128, 64, 32}, ngraph::element::f32, {1, 3, 1}},
|
||||
{{0.3f, 0.2f, 0.1f}, ngraph::element::f32, {1, 3, 1}}}},
|
||||
{ngraph::element::u8,
|
||||
{{}, {}, {}},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {{128}, ngraph::element::f32, {}}, {{0.3f}, ngraph::element::f32, {}}}}},
|
||||
// U8: per-channel quantization, gather axis match with channel, quantization constant shape size is
|
||||
// less than input shape
|
||||
{{1},
|
||||
{1},
|
||||
{1}, // axis
|
||||
std::int64_t{0},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{ngraph::element::u8,
|
||||
{{ngraph::element::f32},
|
||||
{{128, 64, 32}, ngraph::element::f32, {3, 1}},
|
||||
{{0.3f, 0.2f, 0.1f}, ngraph::element::f32, {3, 1}}}},
|
||||
{ngraph::element::u8,
|
||||
{{}, {}, {}},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {{64}, ngraph::element::f32, {}}, {{0.2f}, ngraph::element::f32, {}}}}},
|
||||
// U8: per-channel quantization, gather axis and channel doesn't match
|
||||
{{1},
|
||||
{0},
|
||||
{0},
|
||||
std::int64_t{0},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{ngraph::element::u8,
|
||||
{{ngraph::element::f32},
|
||||
{{128, 64, 32}, ngraph::element::f32, {1, 3, 1}},
|
||||
{{0.3f, 0.2f, 0.1f}, ngraph::element::f32, {1, 3, 1}}}},
|
||||
{ngraph::element::u8,
|
||||
{{}, {}, {}},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32},
|
||||
{{128, 64, 32}, ngraph::element::f32, {1, 3, 1}},
|
||||
{{0.3f, 0.2f, 0.1f}, ngraph::element::f32, {1, 3, 1}}}}},
|
||||
// U8: per-channel quantization, negative axis, gather axis match with channel
|
||||
{{1},
|
||||
{0},
|
||||
{-2}, // axis
|
||||
std::int64_t{0},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{ngraph::element::u8,
|
||||
{{ngraph::element::f32},
|
||||
{{128, 64, 32}, ngraph::element::f32, {1, 3, 1}},
|
||||
{{0.3f, 0.2f, 0.1f}, ngraph::element::f32, {1, 3, 1}}}},
|
||||
{ngraph::element::u8,
|
||||
{{}, {}, {}},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {{128}, ngraph::element::f32, {}}, {{0.3f}, ngraph::element::f32, {}}}}},
|
||||
// empty
|
||||
{{1},
|
||||
{0},
|
||||
{0},
|
||||
std::int64_t{0},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{ngraph::element::u8, {}},
|
||||
{ngraph::element::u8, {}, ngraph::element::u8, {}}},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_LPT,
|
||||
GatherTransformation,
|
||||
::testing::Combine(::testing::ValuesIn(inputShapes3D),
|
||||
::testing::ValuesIn(testValues),
|
||||
::testing::ValuesIn(opset_version)),
|
||||
GatherTransformation::getTestCaseName);
|
||||
} // namespace testValues1
|
||||
|
||||
namespace testValues2 {
|
||||
const std::vector<int> opset_version = {8};
|
||||
|
||||
const std::vector<ngraph::PartialShape> inputShapes3D = {{3, 3, 4}, {-1, -1, -1}};
|
||||
|
||||
const std::vector<GatherTransformationTestValues> testValues = {
|
||||
// U8: per-tensor quantization, negative indices value
|
||||
{{3, 2},
|
||||
{-2, 2, -2, 2, -2, 2}, // indices value
|
||||
{1},
|
||||
std::int64_t{1},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{ngraph::element::u8, {{ngraph::element::f32}, {128}, {0.1f}}},
|
||||
{ngraph::element::u8, {{}, {}, {}}, ngraph::element::u8, {{ngraph::element::f32}, {128}, {0.1f}}}},
|
||||
// U8: per-channel quantization, negative indices value, gather axis match with channel
|
||||
{{1},
|
||||
{-1}, // indices value
|
||||
{1}, // axis
|
||||
std::int64_t{0},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
{ngraph::element::u8,
|
||||
{{ngraph::element::f32},
|
||||
{{128, 64, 32}, ngraph::element::f32, {1, 3, 1}},
|
||||
{{0.3f, 0.2f, 0.1f}, ngraph::element::f32, {1, 3, 1}}}},
|
||||
{ngraph::element::u8,
|
||||
{{}, {}, {}},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {{32}, ngraph::element::f32, {}}, {{0.1f}, ngraph::element::f32, {}}}}},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_LPT,
|
||||
GatherTransformation,
|
||||
::testing::Combine(::testing::ValuesIn(inputShapes3D),
|
||||
::testing::ValuesIn(testValues),
|
||||
::testing::ValuesIn(opset_version)),
|
||||
GatherTransformation::getTestCaseName);
|
||||
} // namespace testValues2
|
||||
|
||||
} // namespace
|
@ -0,0 +1,89 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "low_precision_transformations/gather_transformation.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
const std::vector<ngraph::element::Type> precisions = {
|
||||
ngraph::element::f32,
|
||||
};
|
||||
|
||||
const std::vector<int> opset_version = {
|
||||
1, 7, 8
|
||||
};
|
||||
|
||||
const std::vector<GatherTransformationTestValues> testValues = {
|
||||
// U8: per-tensor quantization
|
||||
{
|
||||
{3, 3, 4},
|
||||
{1},
|
||||
{0},
|
||||
{0},
|
||||
std::int64_t{0},
|
||||
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(),
|
||||
ngraph::element::f32,
|
||||
{256, {}, {0.f}, {25.5f}, {12.5f}, {25.5f + 12.5f}}
|
||||
},
|
||||
// U8: per-channel quantization
|
||||
{
|
||||
{1, 3, 5},
|
||||
{1},
|
||||
{0},
|
||||
{0},
|
||||
std::int64_t{0},
|
||||
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(),
|
||||
ngraph::element::f32,
|
||||
{
|
||||
256,
|
||||
{1, 3, 1},
|
||||
{0.f, 0.f, 0.f},
|
||||
{25.5f, 25.5f, 25.5f},
|
||||
{0.f, 12.5f, 25.5f},
|
||||
{25.5f, 25.5f + 12.5f * 2, 25.5f + 12.5f * 4}
|
||||
}
|
||||
},
|
||||
// U8: per-channel quantization, axis match with dequantization channel, dequantization constant shape is less than gather input shape
|
||||
{
|
||||
{1, 3, 4},
|
||||
{1},
|
||||
{0},
|
||||
{1},
|
||||
std::int64_t{0},
|
||||
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(),
|
||||
ngraph::element::f32,
|
||||
{
|
||||
256,
|
||||
{3, 1},
|
||||
{0.f, 0.f, 0.f},
|
||||
{25.5f, 25.5f, 25.5f},
|
||||
{0.f, 12.5f, 25.5f},
|
||||
{25.5f, 25.5f + 12.5f * 2, 25.5f + 12.5f * 4}
|
||||
}
|
||||
},
|
||||
// 4D
|
||||
{
|
||||
{3, 4, 100, 2},
|
||||
{2},
|
||||
{1, 2},
|
||||
{0},
|
||||
std::int64_t{0},
|
||||
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(),
|
||||
ngraph::element::f32,
|
||||
{256, {}, {0.f}, {25.5f}, {12.5f}, {25.5f + 12.5f}}
|
||||
},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_LPT, GatherTransformation,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(precisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU),
|
||||
::testing::ValuesIn(testValues),
|
||||
::testing::ValuesIn(opset_version)),
|
||||
GatherTransformation::getTestCaseName);
|
||||
} // namespace
|
@ -0,0 +1,72 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include "low_precision_transformations/gather_transformation.hpp"
|
||||
#include "common_test_utils/test_constants.hpp"
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
const std::vector<ngraph::element::Type> precisions = {
|
||||
ngraph::element::f32,
|
||||
ngraph::element::f16
|
||||
};
|
||||
|
||||
const std::vector<int> opset_version = {
|
||||
1, 7, 8
|
||||
};
|
||||
|
||||
const std::vector<GatherTransformationTestValues> testValues = {
|
||||
// U8: per-tensor quantization
|
||||
{
|
||||
{3, 3, 4},
|
||||
{1},
|
||||
{0},
|
||||
{0},
|
||||
std::int64_t{0},
|
||||
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(),
|
||||
ngraph::element::f32,
|
||||
{256, {}, {0.f}, {25.5f}, {12.5f}, {25.5f + 12.5f}}
|
||||
},
|
||||
// U8: per-channel quantization
|
||||
{
|
||||
{1, 3, 5},
|
||||
{1},
|
||||
{0},
|
||||
{0},
|
||||
std::int64_t{0},
|
||||
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(),
|
||||
ngraph::element::f32,
|
||||
{
|
||||
256,
|
||||
{1, 3, 1},
|
||||
{0.f, 0.f, 0.f},
|
||||
{25.5f, 25.5f, 25.5f},
|
||||
{0.f, 12.5f, 25.5f},
|
||||
{25.5f, 25.5f + 12.5f * 2, 25.5f + 12.5f * 4}
|
||||
}
|
||||
},
|
||||
// 4D
|
||||
{
|
||||
{3, 4, 100, 2},
|
||||
{2},
|
||||
{1, 2},
|
||||
{0},
|
||||
std::int64_t{0},
|
||||
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8(),
|
||||
ngraph::element::f32,
|
||||
{256, {}, {0.f}, {25.5f}, {12.5f}, {25.5f + 12.5f}}
|
||||
},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(smoke_LPT, GatherTransformation,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(precisions),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU),
|
||||
::testing::ValuesIn(testValues),
|
||||
::testing::ValuesIn(opset_version)),
|
||||
GatherTransformation::getTestCaseName);
|
||||
} // namespace
|
@ -0,0 +1,43 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include <memory>
|
||||
|
||||
#include "shared_test_classes/base/low_precision_transformations/layer_transformation.hpp"
|
||||
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
class GatherTransformationTestValues {
|
||||
public:
|
||||
ngraph::PartialShape inputShape;
|
||||
std::vector<size_t> gatherIndicesShape;
|
||||
std::vector<int> gatherIndicesValues;
|
||||
std::vector<int> axis;
|
||||
int64_t batch_dims;
|
||||
ngraph::pass::low_precision::LayerTransformation::Params params;
|
||||
ngraph::element::Type precisionBeforeFq;
|
||||
ngraph::builder::subgraph::FakeQuantizeOnData fqOnData;
|
||||
};
|
||||
|
||||
typedef std::tuple<
|
||||
ngraph::element::Type,
|
||||
std::string,
|
||||
GatherTransformationTestValues,
|
||||
int> GatherTransformationParams;
|
||||
|
||||
class GatherTransformation :
|
||||
public testing::WithParamInterface<GatherTransformationParams>,
|
||||
public LayerTestsUtils::LayerTransformation {
|
||||
public:
|
||||
static std::string getTestCaseName(const testing::TestParamInfo<GatherTransformationParams>& obj);
|
||||
|
||||
protected:
|
||||
void SetUp() override;
|
||||
};
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
@ -0,0 +1,56 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "low_precision_transformations/gather_transformation.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <ie_core.hpp>
|
||||
|
||||
#include <transformations/init_node_info.hpp>
|
||||
#include "lpt_ngraph_functions/gather_function.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
std::string GatherTransformation::getTestCaseName(const testing::TestParamInfo<GatherTransformationParams>& obj) {
|
||||
ngraph::element::Type precision;
|
||||
std::string targetDevice;
|
||||
GatherTransformationTestValues testValues;
|
||||
int opset_version;
|
||||
std::tie(precision, targetDevice, testValues, opset_version) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result <<
|
||||
precision << "_" <<
|
||||
targetDevice << "_" <<
|
||||
testValues.inputShape << "_" <<
|
||||
opset_version;
|
||||
|
||||
return result.str();
|
||||
}
|
||||
|
||||
void GatherTransformation::SetUp() {
|
||||
ngraph::element::Type precision;
|
||||
GatherTransformationTestValues testValues;
|
||||
int opset_version;
|
||||
std::tie(precision, targetDevice, testValues, opset_version) = this->GetParam();
|
||||
|
||||
function = ngraph::builder::subgraph::GatherFunction::getOriginal(
|
||||
testValues.inputShape,
|
||||
testValues.gatherIndicesShape,
|
||||
testValues.gatherIndicesValues,
|
||||
testValues.axis,
|
||||
testValues.batch_dims,
|
||||
testValues.precisionBeforeFq,
|
||||
testValues.fqOnData,
|
||||
opset_version);
|
||||
}
|
||||
|
||||
TEST_P(GatherTransformation, CompareWithRefImpl) {
|
||||
Run();
|
||||
};
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
@ -0,0 +1,54 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
|
||||
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
namespace subgraph {
|
||||
|
||||
class GatherFunction {
|
||||
public:
|
||||
static std::shared_ptr<ngraph::Function> getOriginal(
|
||||
const ngraph::PartialShape& inputShape,
|
||||
const std::vector<size_t>& gatherIndicesShape,
|
||||
const std::vector<int>& gatherIndicesValues,
|
||||
const std::vector<int>& axis,
|
||||
const int64_t batch_dims,
|
||||
const ngraph::element::Type precisionBeforeDequantization,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
|
||||
const int opset_version);
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getOriginal(
|
||||
const ngraph::PartialShape& inputShape,
|
||||
const std::vector<size_t>& gatherIndicesShape,
|
||||
const std::vector<int>& gatherIndicesValues,
|
||||
const std::vector<int>& axis,
|
||||
const int64_t batch_dims,
|
||||
const ngraph::element::Type precisionBeforeFq,
|
||||
const FakeQuantizeOnData& fqOnData,
|
||||
const int opset_version);
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getReference(
|
||||
const ngraph::PartialShape& inputShape,
|
||||
const std::vector<size_t>& gatherIndicesShape,
|
||||
const std::vector<int>& gatherIndicesValues,
|
||||
const std::vector<int>& axis,
|
||||
const int64_t batch_dims,
|
||||
const ngraph::element::Type precisionBeforeDequantization,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
|
||||
const ngraph::element::Type precisionAfterOperation,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter,
|
||||
const int opset_version);
|
||||
};
|
||||
|
||||
} // namespace subgraph
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
@ -0,0 +1,133 @@
|
||||
// Copyright (C) 2018-2022 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "lpt_ngraph_functions/gather_function.hpp"
|
||||
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset7.hpp>
|
||||
#include <ngraph/opsets/opset8.hpp>
|
||||
#include "lpt_ngraph_functions/common/builders.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
namespace subgraph {
|
||||
|
||||
std::shared_ptr<ngraph::Function> GatherFunction::getOriginal(
|
||||
const ngraph::PartialShape& inputShape,
|
||||
const std::vector<size_t>& gatherIndicesShape,
|
||||
const std::vector<int>& gatherIndicesValues,
|
||||
const std::vector<int>& axis,
|
||||
const int64_t batch_dims,
|
||||
const ngraph::element::Type precisionBeforeDequantization,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
|
||||
const int opset_version) {
|
||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization, inputShape);
|
||||
const std::shared_ptr<Node> dequantizationOp = makeDequantization(input, dequantization);
|
||||
const auto indicesNode = std::make_shared<ngraph::opset1::Constant>(
|
||||
ngraph::element::i64,
|
||||
ngraph::Shape(gatherIndicesShape),
|
||||
gatherIndicesValues);
|
||||
const auto axisNode = std::make_shared<ngraph::op::Constant>(ngraph::element::i64, ngraph::Shape{ axis.size() }, axis);
|
||||
std::shared_ptr<Node> gather;
|
||||
if (opset_version == 7) {
|
||||
gather = std::make_shared<ngraph::opset7::Gather>(dequantizationOp, indicesNode, axisNode, batch_dims);
|
||||
} else if (opset_version == 8) {
|
||||
gather = std::make_shared<ngraph::opset8::Gather>(dequantizationOp, indicesNode, axisNode, batch_dims);
|
||||
} else if (opset_version == 1) {
|
||||
gather = std::make_shared<ngraph::opset1::Gather>(dequantizationOp, indicesNode, axisNode);
|
||||
} else {
|
||||
throw std::runtime_error("Unknown opset version");
|
||||
}
|
||||
gather->set_friendly_name("output");
|
||||
|
||||
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(gather) };
|
||||
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "GatherFunction");
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> GatherFunction::getOriginal(
|
||||
const ngraph::PartialShape& inputShape,
|
||||
const std::vector<size_t>& gatherIndicesShape,
|
||||
const std::vector<int>& gatherIndicesValues,
|
||||
const std::vector<int>& axis,
|
||||
const int64_t batch_dims,
|
||||
const ngraph::element::Type precisionBeforeFq,
|
||||
const FakeQuantizeOnData& fqOnData,
|
||||
const int opset_version) {
|
||||
|
||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeFq, inputShape);
|
||||
|
||||
const std::shared_ptr<Node> quantizationOp = fqOnData.empty() ?
|
||||
std::dynamic_pointer_cast<ngraph::Node>(input) :
|
||||
makeFakeQuantize(input, precisionBeforeFq, fqOnData);
|
||||
|
||||
const auto indicesNode = std::make_shared<ngraph::opset1::Constant>(
|
||||
ngraph::element::i64,
|
||||
ngraph::Shape(gatherIndicesShape),
|
||||
gatherIndicesValues);
|
||||
const auto axisNode = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{ axis.size() }, axis);
|
||||
|
||||
std::shared_ptr<Node> gather;
|
||||
if (opset_version == 7) {
|
||||
gather = std::make_shared<ngraph::opset7::Gather>(quantizationOp, indicesNode, axisNode, batch_dims);
|
||||
} else if (opset_version == 8) {
|
||||
gather = std::make_shared<ngraph::opset8::Gather>(quantizationOp, indicesNode, axisNode, batch_dims);
|
||||
} else if (opset_version == 1) {
|
||||
gather = std::make_shared<ngraph::opset1::Gather>(quantizationOp, indicesNode, axisNode);
|
||||
} else {
|
||||
throw std::runtime_error("Unknown opset version");
|
||||
}
|
||||
|
||||
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(gather) };
|
||||
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "GatherFunction");
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> GatherFunction::getReference(
|
||||
const ngraph::PartialShape& inputShape,
|
||||
const std::vector<size_t>& gatherIndicesShape,
|
||||
const std::vector<int>& gatherIndicesValues,
|
||||
const std::vector<int>& axis,
|
||||
const int64_t batch_dims,
|
||||
const ngraph::element::Type precisionBeforeDequantization,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
|
||||
const ngraph::element::Type precisionAfterOperation,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter,
|
||||
const int opset_version) {
|
||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(precisionBeforeDequantization, inputShape);
|
||||
|
||||
const std::shared_ptr<Node> quantizationOpBefore = makeDequantization(input, dequantizationBefore);
|
||||
|
||||
const auto indicesNode = std::make_shared<ngraph::opset1::Constant>(
|
||||
ngraph::element::i64,
|
||||
ngraph::Shape(gatherIndicesShape),
|
||||
gatherIndicesValues);
|
||||
const auto axisNode = std::make_shared<ngraph::opset1::Constant>(ngraph::element::i64, ngraph::Shape{ axis.size() }, axis);
|
||||
|
||||
std::shared_ptr<Node> gather;
|
||||
if (opset_version == 7) {
|
||||
gather = std::make_shared<ngraph::opset7::Gather>(quantizationOpBefore, indicesNode, axisNode, batch_dims);
|
||||
} else if (opset_version == 8) {
|
||||
gather = std::make_shared<ngraph::opset8::Gather>(quantizationOpBefore, indicesNode, axisNode, batch_dims);
|
||||
} else if (opset_version == 1) {
|
||||
gather = std::make_shared<ngraph::opset1::Gather>(quantizationOpBefore, indicesNode, axisNode);
|
||||
} else {
|
||||
throw std::runtime_error("Unknown opset version");
|
||||
}
|
||||
|
||||
if (quantizationOpBefore->get_output_element_type(0) != precisionAfterOperation) {
|
||||
THROW_IE_LPT_EXCEPTION(*quantizationOpBefore) << "unexpected precision '" << precisionAfterOperation << "' after operation";
|
||||
}
|
||||
if (gather->get_output_element_type(0) != precisionAfterOperation) {
|
||||
THROW_IE_LPT_EXCEPTION(*gather) << "unexpected precision '" << precisionAfterOperation << "' after operation";
|
||||
}
|
||||
|
||||
const std::shared_ptr<Node> quantizationOpAfter = makeDequantization(gather, dequantizationAfter);
|
||||
quantizationOpAfter->set_friendly_name("output");
|
||||
|
||||
ngraph::ResultVector results{ std::make_shared<ngraph::opset1::Result>(quantizationOpAfter) };
|
||||
return std::make_shared<ngraph::Function>(results, ngraph::ParameterVector{ input }, "GatherFunction");
|
||||
}
|
||||
|
||||
} // namespace subgraph
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
Loading…
Reference in New Issue
Block a user