[LPT] StridedSlice Transformation (#3817)
* [nGraph] evaluate_strided_slice: replace read_vec to host_tensor_2_vec * [LPT] StridedSliceTransformation
This commit is contained in:
parent
bad4e97d9b
commit
ef72e21213
@ -0,0 +1,26 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include "layer_transformation.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
namespace low_precision {
|
||||
|
||||
class TRANSFORMATIONS_API StridedSliceTransformation : public LayerTransformation {
|
||||
public:
|
||||
StridedSliceTransformation(const Params& params);
|
||||
void registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const override;
|
||||
bool transform(TransformationContext& context, ngraph::pattern::Matcher& m) const override;
|
||||
bool canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> op) const override;
|
||||
bool isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept override;
|
||||
};
|
||||
|
||||
} // namespace low_precision
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
@ -0,0 +1,105 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "low_precision/strided_slice.hpp"
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
|
||||
#include "low_precision/network_helper.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace pass {
|
||||
namespace low_precision {
|
||||
|
||||
std::shared_ptr<Node> stridedSliceDeqConstant(
|
||||
const std::shared_ptr<ngraph::Node> strSlice,
|
||||
const std::shared_ptr<ngraph::Node> dequantizaitonConstant) {
|
||||
auto constant = as_type_ptr<ngraph::opset1::Constant>(dequantizaitonConstant);
|
||||
if (NetworkHelper::isScalarLike(constant)) {
|
||||
return NetworkHelper::toScalar(constant);
|
||||
}
|
||||
|
||||
if (strSlice->get_input_shape(0).size() != constant->get_shape().size()) {
|
||||
const auto constantShape = constant->get_shape();
|
||||
const auto stridedSliceShape = strSlice->get_input_shape(0);
|
||||
ngraph::Shape newConstantShape(stridedSliceShape.size(), 1);
|
||||
|
||||
for (size_t i = 0; i < constantShape.size(); ++i) {
|
||||
if (constantShape[i] != 1) {
|
||||
newConstantShape[i] = constantShape[i];
|
||||
}
|
||||
}
|
||||
|
||||
const auto newConstant = fold<ngraph::opset1::Broadcast>(
|
||||
constant,
|
||||
ngraph::opset1::Constant::create(ngraph::element::i32, { newConstantShape.size() }, newConstantShape));
|
||||
constant = as_type_ptr<ngraph::opset1::Constant>(newConstant);
|
||||
}
|
||||
|
||||
const auto stridedSlice = as_type_ptr<ngraph::opset1::StridedSlice>(strSlice);
|
||||
return fold<ngraph::opset1::StridedSlice>(
|
||||
constant,
|
||||
stridedSlice->get_input_node_shared_ptr(1),
|
||||
stridedSlice->get_input_node_shared_ptr(2),
|
||||
stridedSlice->get_input_node_shared_ptr(3),
|
||||
stridedSlice->get_begin_mask(),
|
||||
stridedSlice->get_end_mask(),
|
||||
stridedSlice->get_new_axis_mask(),
|
||||
stridedSlice->get_shrink_axis_mask(),
|
||||
stridedSlice->get_ellipsis_mask());
|
||||
}
|
||||
|
||||
StridedSliceTransformation::StridedSliceTransformation(const Params& params) : LayerTransformation(params) {}
|
||||
|
||||
void StridedSliceTransformation::registerMatcherIn(GraphRewrite& pass, TransformationContext& context) const {
|
||||
addPattern(pass,
|
||||
context,
|
||||
make_op_pattern<opset1::StridedSlice>({
|
||||
make_op_label<opset1::Multiply>(),
|
||||
make_op_label<opset1::Constant>(),
|
||||
make_op_label<opset1::Constant>(),
|
||||
make_op_label<opset1::Constant>() }));
|
||||
}
|
||||
|
||||
bool StridedSliceTransformation::transform(TransformationContext& context, ngraph::pattern::Matcher& m) const {
|
||||
if (!StridedSliceTransformation::canBeTransformed(context, m.get_match_root())) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const auto stridedSlice = separateInStandaloneBranch(m.get_match_root());
|
||||
const auto dequantization = NetworkHelper::getDequantization(stridedSlice);
|
||||
|
||||
if (dequantization.subtract) {
|
||||
const auto subConst = NetworkHelper::getConstantInput(dequantization.subtract);
|
||||
const size_t subConstIdx = NetworkHelper::getChildInputIndex(subConst, dequantization.subtract);
|
||||
|
||||
const auto newSubConst = stridedSliceDeqConstant(stridedSlice, subConst);
|
||||
dequantization.subtract->set_argument(subConstIdx, newSubConst);
|
||||
}
|
||||
|
||||
const auto mulConst = NetworkHelper::getConstantInput(dequantization.multiply);
|
||||
const size_t mulConstIdx = NetworkHelper::getChildInputIndex(mulConst, dequantization.multiply);
|
||||
|
||||
const auto newMulConst = stridedSliceDeqConstant(stridedSlice, mulConst);
|
||||
dequantization.multiply->set_argument(mulConstIdx, newMulConst);
|
||||
|
||||
moveDequantizationAfter(context, stridedSlice, dequantization, false);
|
||||
return true;
|
||||
}
|
||||
|
||||
bool StridedSliceTransformation::canBeTransformed(const TransformationContext& context, std::shared_ptr<Node> operation) const {
|
||||
if (!is_type<ngraph::opset1::StridedSlice>(operation)) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return !NetworkHelper::getDequantization(operation).empty();
|
||||
}
|
||||
|
||||
bool StridedSliceTransformation::isPrecisionPreserved(std::shared_ptr<Node> layer) const noexcept {
|
||||
return true;
|
||||
}
|
||||
} // namespace low_precision
|
||||
} // namespace pass
|
||||
} // namespace ngraph
|
@ -44,6 +44,7 @@
|
||||
#include "low_precision/squeeze.hpp"
|
||||
#include "low_precision/subtract.hpp"
|
||||
#include "low_precision/split.hpp"
|
||||
#include "low_precision/strided_slice.hpp"
|
||||
#include "low_precision/transpose.hpp"
|
||||
#include "low_precision/unsqueeze.hpp"
|
||||
#include "low_precision/variadic_split.hpp"
|
||||
@ -218,6 +219,7 @@ LowPrecisionTransformations LowPrecisionTransformer::getAllTransformations(const
|
||||
add<ReluTransformation, opset1::Relu>(params).
|
||||
add<ReshapeTransformation, opset1::Reshape>(params).
|
||||
add<SqueezeTransformation, opset1::Squeeze>(params).
|
||||
add<StridedSliceTransformation, opset1::StridedSlice>(params).
|
||||
add<TransposeTransformation, opset1::Transpose>(params).
|
||||
add<UnsqueezeTransformation, opset1::Unsqueeze>(params).
|
||||
add<InterpolateTransformation, opset4::Interpolate>(params).
|
||||
|
@ -0,0 +1,376 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "layer_transformation.hpp"
|
||||
|
||||
#include <string>
|
||||
#include <sstream>
|
||||
#include <memory>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include <utility>
|
||||
#include <transformations/utils/utils.hpp>
|
||||
|
||||
#include "common_test_utils/ngraph_test_utils.hpp"
|
||||
#include "simple_low_precision_transformer.hpp"
|
||||
#include "low_precision/strided_slice.hpp"
|
||||
|
||||
#include "lpt_ngraph_functions/strided_slice_function.hpp"
|
||||
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
|
||||
|
||||
using namespace testing;
|
||||
using namespace ngraph::pass;
|
||||
using namespace ngraph::builder::subgraph;
|
||||
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& values) {
|
||||
os << "{ ";
|
||||
for (size_t i = 0; i < values.size(); ++i) {
|
||||
os << values[i];
|
||||
if (i != (values.size() - 1ul)) {
|
||||
os << ", ";
|
||||
}
|
||||
}
|
||||
os << " }";
|
||||
return os;
|
||||
}
|
||||
|
||||
class StridedSliceTransformationTestValues {
|
||||
public:
|
||||
class Actual {
|
||||
public:
|
||||
ngraph::element::Type inputPrecision;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantization;
|
||||
};
|
||||
|
||||
class Expected {
|
||||
public:
|
||||
ngraph::element::Type inputPrecision;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantizationBefore;
|
||||
ngraph::element::Type preicsionAfterOperation;
|
||||
ngraph::builder::subgraph::DequantizationOperations dequantizationAfter;
|
||||
};
|
||||
|
||||
struct LayerParams {
|
||||
std::vector<int64_t> begin;
|
||||
std::vector<int64_t> end;
|
||||
std::vector<int64_t> strides;
|
||||
std::vector<int64_t> beginMask;
|
||||
std::vector<int64_t> endMask;
|
||||
std::vector<int64_t> newAxisMask;
|
||||
std::vector<int64_t> shrinkAxisMask;
|
||||
std::vector<int64_t> elipsisMask;
|
||||
};
|
||||
|
||||
ngraph::Shape inputShape;
|
||||
ngraph::pass::low_precision::LayerTransformation::Params params;
|
||||
LayerParams layerParams;
|
||||
Actual actual;
|
||||
Expected expected;
|
||||
};
|
||||
|
||||
class StridedSliceTransformation : public LayerTransformation, public testing::WithParamInterface<StridedSliceTransformationTestValues> {
|
||||
public:
|
||||
void SetUp() override {
|
||||
const StridedSliceTransformationTestValues testValues = GetParam();
|
||||
|
||||
actualFunction = ngraph::builder::subgraph::StridedSliceFunction::getOriginal(
|
||||
testValues.actual.inputPrecision,
|
||||
testValues.inputShape,
|
||||
testValues.actual.dequantization,
|
||||
testValues.layerParams.begin,
|
||||
testValues.layerParams.end,
|
||||
testValues.layerParams.strides,
|
||||
testValues.layerParams.beginMask,
|
||||
testValues.layerParams.endMask,
|
||||
testValues.layerParams.newAxisMask,
|
||||
testValues.layerParams.shrinkAxisMask,
|
||||
testValues.layerParams.elipsisMask);
|
||||
|
||||
SimpleLowPrecisionTransformer transformer;
|
||||
transformer.add<ngraph::pass::low_precision::StridedSliceTransformation, ngraph::opset1::StridedSlice>(testValues.params);
|
||||
transformer.transform(actualFunction);
|
||||
|
||||
referenceFunction = ngraph::builder::subgraph::StridedSliceFunction::getReference(
|
||||
testValues.expected.inputPrecision,
|
||||
testValues.inputShape,
|
||||
testValues.layerParams.begin,
|
||||
testValues.layerParams.end,
|
||||
testValues.layerParams.strides,
|
||||
testValues.layerParams.beginMask,
|
||||
testValues.layerParams.endMask,
|
||||
testValues.layerParams.newAxisMask,
|
||||
testValues.layerParams.shrinkAxisMask,
|
||||
testValues.layerParams.elipsisMask,
|
||||
testValues.expected.dequantizationBefore,
|
||||
testValues.expected.preicsionAfterOperation,
|
||||
testValues.expected.dequantizationAfter);
|
||||
}
|
||||
|
||||
static std::string getTestCaseName(testing::TestParamInfo<StridedSliceTransformationTestValues> obj) {
|
||||
const StridedSliceTransformationTestValues testValues = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result <<
|
||||
testValues.inputShape << testValues.actual.inputPrecision << "_" << toString(testValues.params) <<
|
||||
testValues.actual.dequantization << "_strided_slice_params_" << testValues.layerParams.begin <<
|
||||
testValues.layerParams.end << testValues.layerParams.beginMask <<
|
||||
testValues.layerParams.endMask << testValues.layerParams.strides;
|
||||
return result.str();
|
||||
}
|
||||
};
|
||||
|
||||
TEST_P(StridedSliceTransformation, CompareFunctions) {
|
||||
actualFunction->validate_nodes_and_infer_types();
|
||||
auto res = compare_functions(referenceFunction, actualFunction, true, true, true);
|
||||
ASSERT_TRUE(res.first) << res.second;
|
||||
}
|
||||
|
||||
StridedSliceTransformationTestValues::LayerParams channelSlice = {
|
||||
{ 0, 0, 0, 0 }, // begin
|
||||
{ 1, 2, 1, 1 }, // end
|
||||
{ 1, 1, 1, 1 }, // strided
|
||||
{ 1, 0, 1, 1 }, // beginMask
|
||||
{ 1, 0, 1, 1 }, // endMask
|
||||
{}, // newAxisMask
|
||||
{}, // shrinkAxisMask
|
||||
{} // elipsisMask
|
||||
};
|
||||
|
||||
StridedSliceTransformationTestValues::LayerParams specialDimensionSlice = {
|
||||
{ 0, 0, 0, 0 },
|
||||
{ 1, 3, 20, 24 },
|
||||
{ 1, 1, 1, 1 },
|
||||
{ 1, 1, 0, 1 },
|
||||
{ 1, 1, 0, 1 },
|
||||
{},
|
||||
{},
|
||||
{}
|
||||
};
|
||||
|
||||
const std::vector<StridedSliceTransformationTestValues> stridedSliceTransformationTestValues = {
|
||||
// U8: channel slice, per-tensor quantization
|
||||
{
|
||||
ngraph::Shape{1, 3, 24, 24},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
channelSlice,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, { 128.f }, { 0.1f }}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, { 128.f }, { 0.1f }}
|
||||
}
|
||||
},
|
||||
// U8: channel slice, per-channel quantization with the same values
|
||||
{
|
||||
ngraph::Shape{1, 3, 24, 24},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
channelSlice,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {{ 128.f, 128.f, 128.f }}, {{ 0.1f, 0.1f, 0.1f }}}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, { 128.f }, { 0.1f }}
|
||||
}
|
||||
},
|
||||
// U8: channel slice, per-channel quantization with different values
|
||||
{
|
||||
ngraph::Shape{1, 3, 24, 24},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
channelSlice,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {{ 128.f, 64.f, 128.f }}, {{ 0.1f, 0.01f, 1.f }}}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {{ 128.f, 64.f }}, {{ 0.1f, 0.01f }}}
|
||||
}
|
||||
},
|
||||
// U8: special dimension slice, per-channel quantization with different values
|
||||
{
|
||||
ngraph::Shape{1, 3, 24, 24},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
specialDimensionSlice,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {{ 128.f, 64.f, 128.f }}, {{ 0.1f, 0.01f, 1.f }}}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {{ 128.f, 64.f, 128.f }}, {{ 0.1f, 0.01f, 1.f }}}
|
||||
}
|
||||
},
|
||||
// U8: without subtract
|
||||
{
|
||||
ngraph::Shape{1, 3, 24, 24},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
channelSlice,
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {}, { 0.1f }}
|
||||
},
|
||||
{
|
||||
ngraph::element::u8,
|
||||
{},
|
||||
ngraph::element::u8,
|
||||
{{ngraph::element::f32}, {}, { 0.1f }}
|
||||
}
|
||||
},
|
||||
// U8: without convert
|
||||
{
|
||||
ngraph::Shape{1, 3, 24, 24},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
specialDimensionSlice,
|
||||
{
|
||||
ngraph::element::f32,
|
||||
{{}, { 128.f }, { 0.1f }}
|
||||
},
|
||||
{
|
||||
ngraph::element::f32,
|
||||
{},
|
||||
ngraph::element::f32,
|
||||
{{}, { 128.f }, { 0.1f }}
|
||||
}
|
||||
},
|
||||
// I8: channel slice, per-tensor quantization
|
||||
{
|
||||
ngraph::Shape{1, 3, 24, 24},
|
||||
LayerTransformation::createParamsI8I8(),
|
||||
channelSlice,
|
||||
{
|
||||
ngraph::element::i8,
|
||||
{{ngraph::element::f32}, { 32.f }, { 0.1f }}
|
||||
},
|
||||
{
|
||||
ngraph::element::i8,
|
||||
{},
|
||||
ngraph::element::i8,
|
||||
{{ngraph::element::f32}, { 32.f }, { 0.1f }}
|
||||
}
|
||||
},
|
||||
// I8: channel slice, per-channel quantization with the same values
|
||||
{
|
||||
ngraph::Shape{1, 3, 24, 24},
|
||||
LayerTransformation::createParamsI8I8(),
|
||||
channelSlice,
|
||||
{
|
||||
ngraph::element::i8,
|
||||
{{ngraph::element::f32}, {{ 32.f, 32.f, 32.f }}, {{ 0.1f, 0.1f, 0.1f }}}
|
||||
},
|
||||
{
|
||||
ngraph::element::i8,
|
||||
{},
|
||||
ngraph::element::i8,
|
||||
{{ngraph::element::f32}, { 32.f }, { 0.1f }}
|
||||
}
|
||||
},
|
||||
// I8: channel slice, per-channel quantization with different values
|
||||
{
|
||||
ngraph::Shape{1, 3, 24, 24},
|
||||
LayerTransformation::createParamsI8I8(),
|
||||
channelSlice,
|
||||
{
|
||||
ngraph::element::i8,
|
||||
{{ngraph::element::f32}, {{ 32.f, 64.f, 32.f }}, {{ 0.1f, 0.01f, 1.f }}}
|
||||
},
|
||||
{
|
||||
ngraph::element::i8,
|
||||
{},
|
||||
ngraph::element::i8,
|
||||
{{ngraph::element::f32}, {{ 32.f, 64.f }}, {{ 0.1f, 0.01f }}}
|
||||
}
|
||||
},
|
||||
// I8: special dimension slice, per-channel quantization with different values
|
||||
{
|
||||
ngraph::Shape{1, 3, 24, 24},
|
||||
LayerTransformation::createParamsI8I8(),
|
||||
specialDimensionSlice,
|
||||
{
|
||||
ngraph::element::i8,
|
||||
{{ngraph::element::f32}, {{ 32.f, 64.f, 32.f }}, {{ 0.1f, 0.01f, 1.f }}}
|
||||
},
|
||||
{
|
||||
ngraph::element::i8,
|
||||
{},
|
||||
ngraph::element::i8,
|
||||
{{ngraph::element::f32}, {{ 32.f, 64.f, 32.f }}, {{ 0.1f, 0.01f, 1.f }}}
|
||||
}
|
||||
},
|
||||
// I8: channel slice, quantization by special dimension
|
||||
{
|
||||
ngraph::Shape{1, 3, 4, 4},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
channelSlice,
|
||||
{
|
||||
ngraph::element::i8,
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{32.f, 64.f, 32.f, 64.f}, ngraph::element::f32, {1, 1, 4, 1}},
|
||||
{{3.f, 2.f, 1.f, 3.f}, ngraph::element::f32, {1, 1, 4, 1}}
|
||||
}
|
||||
},
|
||||
{
|
||||
ngraph::element::i8,
|
||||
{},
|
||||
ngraph::element::i8,
|
||||
{
|
||||
{ngraph::element::f32},
|
||||
{{32.f, 64.f, 32.f, 64.f}, ngraph::element::f32, {1, 1, 4, 1}},
|
||||
{{3.f, 2.f, 1.f, 3.f}, ngraph::element::f32, {1, 1, 4, 1}}
|
||||
}
|
||||
}
|
||||
},
|
||||
// channel slice, not update precisions
|
||||
{
|
||||
ngraph::Shape{1, 3, 24, 24},
|
||||
LayerTransformation::createParamsU8I8().setUpdatePrecisions(false),
|
||||
channelSlice,
|
||||
{
|
||||
ngraph::element::f32,
|
||||
{{}, { 128.f }, { 0.1f }}
|
||||
},
|
||||
{
|
||||
ngraph::element::f32,
|
||||
{},
|
||||
ngraph::element::f32,
|
||||
{{}, { 128.f }, { 0.1f }}
|
||||
}
|
||||
},
|
||||
// channel slice, no dequantization
|
||||
{
|
||||
ngraph::Shape{1, 3, 24, 24},
|
||||
LayerTransformation::createParamsU8I8(),
|
||||
channelSlice,
|
||||
{
|
||||
ngraph::element::f32,
|
||||
{}
|
||||
},
|
||||
{
|
||||
ngraph::element::f32,
|
||||
{},
|
||||
ngraph::element::f32,
|
||||
{}
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(
|
||||
smoke_LPT,
|
||||
StridedSliceTransformation,
|
||||
::testing::ValuesIn(stridedSliceTransformationTestValues),
|
||||
StridedSliceTransformation::getTestCaseName);
|
@ -0,0 +1,101 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "low_precision_transformations/strided_slice_transformation.hpp"
|
||||
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
const std::vector<ngraph::element::Type> netPrecisions = {
|
||||
ngraph::element::f32,
|
||||
// ngraph::element::f16
|
||||
};
|
||||
|
||||
const std::vector<ngraph::pass::low_precision::LayerTransformation::Params> trasformationParamValues = {
|
||||
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParams().setUpdatePrecisions(true),
|
||||
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParams().setUpdatePrecisions(false),
|
||||
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsI8I8(),
|
||||
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParamsU8I8()
|
||||
};
|
||||
|
||||
const std::vector<LayerTestsDefinitions::StridedSliceTransformationParam> params = {
|
||||
// channel slice, tensor quantization
|
||||
{
|
||||
{ 256ul, ngraph::Shape{ 1, 1, 1, 1 }, { 0.f }, { 25.5f }, { 0.f }, { 12.8f } },
|
||||
{ 0, 0, 0, 0 }, // begin
|
||||
{ 1, 2, 1, 1 }, // end
|
||||
{ 1, 1, 1, 1 }, // strided
|
||||
{ 1, 0, 1, 1 }, // beginMask
|
||||
{ 1, 0, 1, 1 }, // endMask
|
||||
{},// newAxisMask
|
||||
{},// shrinkAxisMask
|
||||
{}// elipsisMask
|
||||
},
|
||||
// special dimension slice, tensor quantization
|
||||
{
|
||||
{ 256ul, ngraph::Shape{ 1, 1, 1, 1 }, { 0.f }, { 25.5f }, { 0.f }, { 12.8f } },
|
||||
{ 0, 0, 0, 0 },
|
||||
{ 1, 3, 20, 24 },
|
||||
{ 1, 1, 1, 1 },
|
||||
{ 1, 1, 0, 1 },
|
||||
{ 1, 1, 0, 1 },
|
||||
{},
|
||||
{},
|
||||
{}
|
||||
},
|
||||
// channel slice, per-channel quantization
|
||||
{
|
||||
{
|
||||
256ul,
|
||||
ngraph::Shape{ 1, 3, 1, 1 },
|
||||
{ 0.f, 0.f, 0.f },
|
||||
{ 255.f, 25.5f, 2.55f },
|
||||
{ 0.f, 0.f, 0.f },
|
||||
{ 255.f, 25.5f, 2.55f },
|
||||
},
|
||||
{ 0, 0, 0, 0 },
|
||||
{ 1, 2, 1, 1 },
|
||||
{ 1, 1, 1, 1 },
|
||||
{ 1, 0, 1, 1 },
|
||||
{ 1, 0, 1, 1 },
|
||||
{},
|
||||
{},
|
||||
{}
|
||||
},
|
||||
// special dimension slice, per-channel quantization
|
||||
{
|
||||
{
|
||||
256ul,
|
||||
ngraph::Shape{ 1, 3, 1, 1 },
|
||||
{ 0.f, 0.f, 0.f },
|
||||
{ 255.f, 25.5f, 2.55f },
|
||||
{ 0.f, 0.f, 0.f },
|
||||
{ 255.f, 25.5f, 2.55f },
|
||||
},
|
||||
{ 0, 0, 0, 0 },
|
||||
{ 1, 3, 20, 24 },
|
||||
{ 1, 1, 1, 1 },
|
||||
{ 1, 1, 0, 1 },
|
||||
{ 1, 1, 0, 1 },
|
||||
{},
|
||||
{},
|
||||
{}
|
||||
}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_LPT, StridedSliceTransformation,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(ngraph::Shape({ 1, 3, 24, 24 })),
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU),
|
||||
::testing::ValuesIn(trasformationParamValues),
|
||||
::testing::ValuesIn(params)),
|
||||
StridedSliceTransformation::getTestCaseName);
|
||||
|
||||
} // namespace
|
@ -0,0 +1,98 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
|
||||
#include "low_precision_transformations/strided_slice_transformation.hpp"
|
||||
|
||||
|
||||
using namespace LayerTestsDefinitions;
|
||||
|
||||
namespace {
|
||||
const std::vector<ngraph::element::Type> netPrecisions = {
|
||||
ngraph::element::f32,
|
||||
// ngraph::element::f16
|
||||
};
|
||||
|
||||
const std::vector<ngraph::pass::low_precision::LayerTransformation::Params> trasformationParamValues = {
|
||||
LayerTestsUtils::LayerTransformationParamsNGraphFactory::createParams(),
|
||||
};
|
||||
|
||||
const std::vector<LayerTestsDefinitions::StridedSliceTransformationParam> params = {
|
||||
// channel slice, tensor quantization
|
||||
{
|
||||
{ 256ul, ngraph::Shape{ 1, 1, 1, 1 }, { 0.f }, { 25.5f }, { 0.f }, { 12.8f } },
|
||||
{ 0, 0, 0, 0 }, // begin
|
||||
{ 1, 2, 1, 1 }, // end
|
||||
{ 1, 1, 1, 1 }, // strided
|
||||
{ 1, 0, 1, 1 }, // beginMask
|
||||
{ 1, 0, 1, 1 }, // endMask
|
||||
{},// newAxisMask
|
||||
{},// shrinkAxisMask
|
||||
{}// elipsisMask
|
||||
},
|
||||
// special dimension slice, tensor quantization
|
||||
{
|
||||
{ 256ul, ngraph::Shape{ 1, 1, 1, 1 }, { 0.f }, { 25.5f }, { 0.f }, { 12.8f } },
|
||||
{ 0, 0, 0, 0 },
|
||||
{ 1, 3, 20, 24 },
|
||||
{ 1, 1, 1, 1 },
|
||||
{ 1, 1, 0, 1 },
|
||||
{ 1, 1, 0, 1 },
|
||||
{},
|
||||
{},
|
||||
{}
|
||||
},
|
||||
// channel slice, per-channel quantization
|
||||
{
|
||||
{
|
||||
256ul,
|
||||
ngraph::Shape{ 1, 3, 1, 1 },
|
||||
{ 0.f, 0.f, 0.f },
|
||||
{ 255.f, 25.5f, 2.55f },
|
||||
{ 0.f, 0.f, 0.f },
|
||||
{ 255.f, 25.5f, 2.55f },
|
||||
},
|
||||
{ 0, 0, 0, 0 },
|
||||
{ 1, 2, 1, 1 },
|
||||
{ 1, 1, 1, 1 },
|
||||
{ 1, 0, 1, 1 },
|
||||
{ 1, 0, 1, 1 },
|
||||
{},
|
||||
{},
|
||||
{}
|
||||
},
|
||||
// special dimension slice, per-channel quantization
|
||||
{
|
||||
{
|
||||
256ul,
|
||||
ngraph::Shape{ 1, 3, 1, 1 },
|
||||
{ 0.f, 0.f, 0.f },
|
||||
{ 255.f, 25.5f, 2.55f },
|
||||
{ 0.f, 0.f, 0.f },
|
||||
{ 255.f, 25.5f, 2.55f },
|
||||
},
|
||||
{ 0, 0, 0, 0 },
|
||||
{ 1, 3, 20, 24 },
|
||||
{ 1, 1, 1, 1 },
|
||||
{ 1, 1, 0, 1 },
|
||||
{ 1, 1, 0, 1 },
|
||||
{},
|
||||
{},
|
||||
{}
|
||||
}
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_CASE_P(smoke_LPT, StridedSliceTransformation,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(netPrecisions),
|
||||
::testing::Values(ngraph::Shape({ 1, 3, 24, 24 })),
|
||||
::testing::Values(CommonTestUtils::DEVICE_GPU),
|
||||
::testing::ValuesIn(trasformationParamValues),
|
||||
::testing::ValuesIn(params)),
|
||||
StridedSliceTransformation::getTestCaseName);
|
||||
|
||||
} // namespace
|
@ -0,0 +1,45 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "shared_test_classes/base/low_precision_transformations/layer_transformation.hpp"
|
||||
#include "lpt_ngraph_functions/common/fake_quantize_on_data.hpp"
|
||||
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
class StridedSliceTransformationParam {
|
||||
public:
|
||||
ngraph::builder::subgraph::FakeQuantizeOnData fakeQuantize;
|
||||
std::vector<int64_t> begin;
|
||||
std::vector<int64_t> end;
|
||||
std::vector<int64_t> strides;
|
||||
std::vector<int64_t> beginMask;
|
||||
std::vector<int64_t> endMask;
|
||||
std::vector<int64_t> newAxisMask;
|
||||
std::vector<int64_t> shrinkAxisMask;
|
||||
std::vector<int64_t> elipsisMask;
|
||||
};
|
||||
|
||||
typedef std::tuple<
|
||||
ngraph::element::Type,
|
||||
ngraph::Shape,
|
||||
std::string,
|
||||
ngraph::pass::low_precision::LayerTransformation::Params,
|
||||
StridedSliceTransformationParam
|
||||
> StridedSliceTransformationParams;
|
||||
|
||||
class StridedSliceTransformation :
|
||||
public testing::WithParamInterface<StridedSliceTransformationParams>,
|
||||
public LayerTestsUtils::LayerTransformation {
|
||||
public:
|
||||
static std::string getTestCaseName(testing::TestParamInfo<StridedSliceTransformationParams> obj);
|
||||
|
||||
protected:
|
||||
void SetUp() override;
|
||||
|
||||
private:
|
||||
void validate();
|
||||
};
|
||||
} // namespace LayerTestsDefinitions
|
@ -0,0 +1,86 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "low_precision_transformations/strided_slice_transformation.hpp"
|
||||
#include <sstream>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
|
||||
#include "lpt_ngraph_functions/strided_slice_function.hpp"
|
||||
|
||||
namespace LayerTestsDefinitions {
|
||||
|
||||
inline std::ostream& operator<<(std::ostream& os, const std::vector<int64_t>& values) {
|
||||
os << "{ ";
|
||||
for (size_t i = 0; i < values.size(); ++i) {
|
||||
os << values[i];
|
||||
if (i != (values.size() - 1ul)) {
|
||||
os << ", ";
|
||||
}
|
||||
}
|
||||
os << " }";
|
||||
return os;
|
||||
}
|
||||
|
||||
std::string StridedSliceTransformation::getTestCaseName(testing::TestParamInfo<StridedSliceTransformationParams> obj) {
|
||||
ngraph::element::Type netPrecision;
|
||||
ngraph::Shape inputShape;
|
||||
std::string targetDevice;
|
||||
ngraph::pass::low_precision::LayerTransformation::Params params;
|
||||
StridedSliceTransformationParam param;;
|
||||
std::tie(netPrecision, inputShape, targetDevice, params, param) = obj.param;
|
||||
|
||||
std::ostringstream result;
|
||||
result << getTestCaseNameByParams(netPrecision, inputShape, targetDevice, params) << "_" <<
|
||||
param.fakeQuantize << "_" << param.begin << "_" << param.beginMask << "_" <<
|
||||
param.end << "_" << param.endMask << "_" << param.strides << "_" << param.newAxisMask <<
|
||||
param.shrinkAxisMask << "_" << param.elipsisMask;
|
||||
return result.str();
|
||||
}
|
||||
|
||||
void StridedSliceTransformation::SetUp() {
|
||||
ngraph::element::Type netPrecision;
|
||||
ngraph::Shape inputShape;
|
||||
ngraph::pass::low_precision::LayerTransformation::Params params;
|
||||
StridedSliceTransformationParam param;
|
||||
std::tie(netPrecision, inputShape, targetDevice, params, param) = this->GetParam();
|
||||
|
||||
function = ngraph::builder::subgraph::StridedSliceFunction::getOriginal(
|
||||
netPrecision,
|
||||
inputShape,
|
||||
param.fakeQuantize,
|
||||
param.begin,
|
||||
param.end,
|
||||
param.strides,
|
||||
param.beginMask,
|
||||
param.endMask,
|
||||
param.newAxisMask,
|
||||
param.shrinkAxisMask,
|
||||
param.elipsisMask);
|
||||
|
||||
validate();
|
||||
}
|
||||
|
||||
void StridedSliceTransformation::validate() {
|
||||
ngraph::element::Type netPrecision;
|
||||
ngraph::Shape inputShape;
|
||||
std::string targetDevice;
|
||||
ngraph::pass::low_precision::LayerTransformation::Params params;
|
||||
StridedSliceTransformationParam param;
|
||||
std::tie(netPrecision, inputShape, targetDevice, params, param) = this->GetParam();
|
||||
|
||||
const auto transformed = transformNGraph(params, getLowPrecisionTransformationsNGraph(params));
|
||||
|
||||
const auto output = transformed->get_output_op(0);
|
||||
const auto layer = output->get_input_node_shared_ptr(0);
|
||||
const std::string typeName = layer->get_type_name();
|
||||
ASSERT_EQ("ScaleShiftIE", typeName);
|
||||
}
|
||||
|
||||
TEST_P(StridedSliceTransformation, CompareWithRefImpl) {
|
||||
Run();
|
||||
};
|
||||
|
||||
} // namespace LayerTestsDefinitions
|
@ -0,0 +1,64 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <memory>
|
||||
#include <ngraph/ngraph.hpp>
|
||||
#include <low_precision/layer_transformation.hpp>
|
||||
|
||||
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
|
||||
#include "lpt_ngraph_functions/common/builders.hpp"
|
||||
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
namespace subgraph {
|
||||
|
||||
class StridedSliceFunction {
|
||||
public:
|
||||
static std::shared_ptr<ngraph::Function> getOriginal(
|
||||
const ngraph::element::Type inputPrecision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
|
||||
const std::vector<int64_t>& begin,
|
||||
const std::vector<int64_t>& end,
|
||||
const std::vector<int64_t>& strides,
|
||||
const std::vector<int64_t>& beginMask,
|
||||
const std::vector<int64_t>& endMask,
|
||||
const std::vector<int64_t>& newAxisMask,
|
||||
const std::vector<int64_t>& shrinkAxisMask,
|
||||
const std::vector<int64_t>& elipsisMask);
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getOriginal(
|
||||
const ngraph::element::Type inputPrecision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::builder::subgraph::FakeQuantizeOnData& fakeQuantize,
|
||||
const std::vector<int64_t>& begin,
|
||||
const std::vector<int64_t>& end,
|
||||
const std::vector<int64_t>& strides,
|
||||
const std::vector<int64_t>& beginMask,
|
||||
const std::vector<int64_t>& endMask,
|
||||
const std::vector<int64_t>& newAxisMask,
|
||||
const std::vector<int64_t>& shrinkAxisMask,
|
||||
const std::vector<int64_t>& elipsisMask);
|
||||
|
||||
static std::shared_ptr<ngraph::Function> getReference(
|
||||
const ngraph::element::Type inputPrecision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const std::vector<int64_t>& begin,
|
||||
const std::vector<int64_t>& end,
|
||||
const std::vector<int64_t>& strides,
|
||||
const std::vector<int64_t>& beginMask,
|
||||
const std::vector<int64_t>& endMask,
|
||||
const std::vector<int64_t>& newAxisMask,
|
||||
const std::vector<int64_t>& shrinkAxisMask,
|
||||
const std::vector<int64_t>& elipsisMask,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
|
||||
const ngraph::element::Type precisionAfterOperation,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter);
|
||||
};
|
||||
|
||||
} // namespace subgraph
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
@ -0,0 +1,131 @@
|
||||
// Copyright (C) 2021 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "low_precision/network_helper.hpp"
|
||||
#include "low_precision/layer_transformation.hpp"
|
||||
|
||||
#include "ngraph/opsets/opset1.hpp"
|
||||
|
||||
#include "lpt_ngraph_functions/common/dequantization_operations.hpp"
|
||||
#include "ngraph_functions/subgraph_builders.hpp"
|
||||
#include "lpt_ngraph_functions/strided_slice_function.hpp"
|
||||
|
||||
using namespace ngraph::pass::low_precision;
|
||||
|
||||
namespace ngraph {
|
||||
namespace builder {
|
||||
namespace subgraph {
|
||||
|
||||
std::shared_ptr<ngraph::Function> StridedSliceFunction::getOriginal(
|
||||
const ngraph::element::Type inputPrecision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantization,
|
||||
const std::vector<int64_t>& begin,
|
||||
const std::vector<int64_t>& end,
|
||||
const std::vector<int64_t>& strides,
|
||||
const std::vector<int64_t>& beginMask,
|
||||
const std::vector<int64_t>& endMask,
|
||||
const std::vector<int64_t>& newAxisMask,
|
||||
const std::vector<int64_t>& shrinkAxisMask,
|
||||
const std::vector<int64_t>& elipsisMask) {
|
||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(inputPrecision, inputShape);
|
||||
input->set_friendly_name("input");
|
||||
const auto deq = makeDequantization(input, dequantization);
|
||||
|
||||
const auto beginParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ begin.size() }, begin);
|
||||
const auto endParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ end.size() }, end);
|
||||
const auto stridesParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ strides.size() }, strides);
|
||||
|
||||
const auto stridedSlice = std::make_shared<ngraph::opset1::StridedSlice>(
|
||||
deq, beginParam, endParam, stridesParam,
|
||||
beginMask, endMask, newAxisMask,
|
||||
shrinkAxisMask, elipsisMask);
|
||||
stridedSlice->set_friendly_name("StridedSlice");
|
||||
|
||||
const auto res = std::make_shared<ngraph::opset1::Result>(stridedSlice);
|
||||
const auto function = std::make_shared<ngraph::Function>(
|
||||
ngraph::ResultVector{ res },
|
||||
ngraph::ParameterVector{ input },
|
||||
"StridedSliceTransformation");
|
||||
|
||||
return function;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> StridedSliceFunction::getOriginal(
|
||||
const ngraph::element::Type inputPrecision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const ngraph::builder::subgraph::FakeQuantizeOnData& fakeQuantize,
|
||||
const std::vector<int64_t>& begin,
|
||||
const std::vector<int64_t>& end,
|
||||
const std::vector<int64_t>& strides,
|
||||
const std::vector<int64_t>& beginMask,
|
||||
const std::vector<int64_t>& endMask,
|
||||
const std::vector<int64_t>& newAxisMask,
|
||||
const std::vector<int64_t>& shrinkAxisMask,
|
||||
const std::vector<int64_t>& elipsisMask) {
|
||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(inputPrecision, inputShape);
|
||||
input->set_friendly_name("input");
|
||||
const auto fqOnData = makeFakeQuantize(input, inputPrecision, fakeQuantize);
|
||||
|
||||
const auto beginParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ begin.size() }, begin);
|
||||
const auto endParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ end.size() }, end);
|
||||
const auto stridesParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ strides.size() }, strides);
|
||||
|
||||
const auto stridedSlice = std::make_shared<ngraph::opset1::StridedSlice>(
|
||||
fqOnData, beginParam, endParam, stridesParam,
|
||||
beginMask, endMask, newAxisMask,
|
||||
shrinkAxisMask, elipsisMask);
|
||||
stridedSlice->set_friendly_name("StridedSlice");
|
||||
|
||||
const auto res = std::make_shared<ngraph::opset1::Result>(stridedSlice);
|
||||
const auto function = std::make_shared<ngraph::Function>(
|
||||
ngraph::ResultVector{ res },
|
||||
ngraph::ParameterVector{ input },
|
||||
"StridedSliceTransformation");
|
||||
|
||||
return function;
|
||||
}
|
||||
|
||||
std::shared_ptr<ngraph::Function> StridedSliceFunction::getReference(
|
||||
const ngraph::element::Type inputPrecision,
|
||||
const ngraph::Shape& inputShape,
|
||||
const std::vector<int64_t>& begin,
|
||||
const std::vector<int64_t>& end,
|
||||
const std::vector<int64_t>& strides,
|
||||
const std::vector<int64_t>& beginMask,
|
||||
const std::vector<int64_t>& endMask,
|
||||
const std::vector<int64_t>& newAxisMask,
|
||||
const std::vector<int64_t>& shrinkAxisMask,
|
||||
const std::vector<int64_t>& elipsisMask,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantizationBefore,
|
||||
const ngraph::element::Type precisionAfterOperation,
|
||||
const ngraph::builder::subgraph::DequantizationOperations& dequantizationAfter) {
|
||||
const auto input = std::make_shared<ngraph::opset1::Parameter>(inputPrecision, inputShape);
|
||||
input->set_friendly_name("input");
|
||||
const auto deqBefore = makeDequantization(input, dequantizationBefore);
|
||||
|
||||
const auto beginParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ begin.size() }, begin);
|
||||
const auto endParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ end.size() }, end);
|
||||
const auto stridesParam = ngraph::op::Constant::create(ngraph::element::i64, ngraph::Shape{ strides.size() }, strides);
|
||||
|
||||
const auto stridedSlice = std::make_shared<ngraph::opset1::StridedSlice>(
|
||||
deqBefore, beginParam, endParam, stridesParam,
|
||||
beginMask, endMask, newAxisMask,
|
||||
shrinkAxisMask, elipsisMask);
|
||||
|
||||
const auto deqAfter = makeDequantization(stridedSlice, dequantizationAfter);
|
||||
deqAfter->set_friendly_name("StridedSlice");
|
||||
|
||||
const auto res = std::make_shared<ngraph::opset1::Result>(deqAfter);
|
||||
const auto function = std::make_shared<ngraph::Function>(
|
||||
ngraph::ResultVector{ res },
|
||||
ngraph::ParameterVector{ input },
|
||||
"StridedSliceTransformation");
|
||||
|
||||
return function;
|
||||
}
|
||||
|
||||
} // namespace subgraph
|
||||
} // namespace builder
|
||||
} // namespace ngraph
|
@ -34,8 +34,8 @@ namespace ngraph
|
||||
class NGRAPH_API StridedSlice : public Op
|
||||
{
|
||||
public:
|
||||
static constexpr NodeTypeInfo type_info{"StridedSlice", 1};
|
||||
const NodeTypeInfo& get_type_info() const override { return type_info; }
|
||||
NGRAPH_RTTI_DECLARATION;
|
||||
|
||||
StridedSlice() = default;
|
||||
|
||||
/// \brief Constructs a dynamic tensor strided slice operation.
|
||||
|
@ -34,7 +34,7 @@
|
||||
using namespace std;
|
||||
using namespace ngraph;
|
||||
|
||||
constexpr NodeTypeInfo op::v1::StridedSlice::type_info;
|
||||
NGRAPH_RTTI_DEFINITION(op::v1::StridedSlice, "StridedSlice", 1);
|
||||
|
||||
op::v1::StridedSlice::StridedSlice(const Output<Node>& data,
|
||||
const Output<Node>& begin,
|
||||
@ -265,9 +265,9 @@ namespace strided_slice
|
||||
const AxisSet& ellipsis_mask,
|
||||
const HostTensorPtr& out)
|
||||
{
|
||||
std::vector<int64_t> begin_const = read_vector<int64_t>(begin);
|
||||
std::vector<int64_t> end_const = read_vector<int64_t>(end);
|
||||
std::vector<int64_t> stride_const = read_vector<int64_t>(stride);
|
||||
std::vector<int64_t> begin_const = host_tensor_2_vector<int64_t>(begin);
|
||||
std::vector<int64_t> end_const = host_tensor_2_vector<int64_t>(end);
|
||||
std::vector<int64_t> stride_const = host_tensor_2_vector<int64_t>(stride);
|
||||
SlicePlan slice_plan = make_slice_plan(in->get_shape(),
|
||||
begin_const,
|
||||
end_const,
|
||||
|
Loading…
Reference in New Issue
Block a user