[NGraph] Add scatterNDUpdate and scatterUpdate reference implementations (#1494)

Maxim Andronov 2020-08-07 16:09:28 +03:00 committed by GitHub
parent caa38130b9
commit 4054364fbf
11 changed files with 235 additions and 200 deletions

View File

@@ -64,7 +64,6 @@ set(LAYERS
${CMAKE_CURRENT_SOURCE_DIR}/nodes/gather_tree.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/grn.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/non_max_suppression.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/scatter.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/log_softmax.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/math.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/one_hot.cpp

View File

@@ -51,7 +51,6 @@ MKLDNN_EXTENSION_NODE(FillImpl, Fill);
MKLDNN_EXTENSION_NODE(UniqueImpl, Unique);
MKLDNN_EXTENSION_NODE(PSROIPoolingImpl, PSROIPooling);
MKLDNN_EXTENSION_NODE(DepthToSpaceImpl, DepthToSpace);
MKLDNN_EXTENSION_NODE(ScatterImpl, ScatterUpdate);
MKLDNN_EXTENSION_NODE(OneHotImpl, OneHot);
MKLDNN_EXTENSION_NODE(BroadcastImpl, Broadcast);
MKLDNN_EXTENSION_NODE(ExperimentalSparseWeightedReduceImpl, ExperimentalSparseWeightedSum);

View File

@@ -1,188 +0,0 @@
// Copyright (C) 2018-2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "base.hpp"
#include <cmath>
#include <string>
#include <vector>
#include <cassert>
#include <algorithm>
#include <limits>
#include "ie_parallel.hpp"
#include "common/simple_copy.h"
namespace InferenceEngine {
namespace Extensions {
namespace Cpu {
class ScatterImpl: public ExtLayerBase {
public:
explicit ScatterImpl(const CNNLayer* layer) {
try {
if (layer->insData.size() != 3 || layer->outData.size() != 1)
THROW_IE_EXCEPTION << layer->name << " Incorrect number of input/output tensors!";
Precision inIdxPrecision = layer->insData[SCATTER_INDEXES].lock()->getTensorDesc().getPrecision();
if (inIdxPrecision != Precision::FP32 && inIdxPrecision != Precision::I32)
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Indexes' precision. Only FP32 or I32 are supported!";
Precision inDataPrecision = layer->insData[SCATTER_DATA].lock()->getTensorDesc().getPrecision();
if (inDataPrecision != layer->insData[SCATTER_UPDATES].lock()->getTensorDesc().getPrecision())
THROW_IE_EXCEPTION << layer->name << " Precision should be equal for input tensors 'Data' and 'Updates'";
// Remove redundant dimensions
const SizeVector& data_dims = layer->insData[SCATTER_DATA].lock()->getTensorDesc().getDims();
if (data_dims.size() == 0 ||
(data_dims.size() == 1 && data_dims[0] == 1) ||
layer->insData[SCATTER_DATA].lock()->getTensorDesc().getLayout() == Layout::SCALAR)
THROW_IE_EXCEPTION << layer->name << " 'Data' tensor rank should be >= 1";
axis = layer->GetParamAsInt("axis", 0);
IE_ASSERT(-static_cast<int>(data_dims.size()) <= axis && axis < static_cast<int>(data_dims.size()))
<< layer->name << " Incorrect input parameters dimensions and axis number!";
if (axis < 0)
axis += data_dims.size();
SizeVector dst_dims = layer->outData[0]->getTensorDesc().getDims();
if (data_dims != dst_dims)
THROW_IE_EXCEPTION << layer->name << " Incorrect number of input/output dimensions!";
SizeVector idx_dims = layer->insData[SCATTER_INDEXES].lock()->getTensorDesc().getDims();
if (idx_dims.size() == 0 ||
(idx_dims.size() == 1 && idx_dims[0] == 1) ||
layer->insData[SCATTER_INDEXES].lock()->getTensorDesc().getLayout() == Layout::SCALAR)
THROW_IE_EXCEPTION << layer->name << " 'Indexes' tensor rank should be >= 1";
SizeVector upd_dims = layer->insData[SCATTER_UPDATES].lock()->getTensorDesc().getDims();
if (layer->insData[SCATTER_UPDATES].lock()->getTensorDesc().getLayout() == Layout::SCALAR)
THROW_IE_EXCEPTION << layer->name << " 'Updates' tensor rank should be >= 1";
if (idx_dims != upd_dims)
THROW_IE_EXCEPTION << layer->name << " 'Indexes' and 'Updates' tensors must have the same dimensions";
for (size_t i = 0; i < idx_dims.size(); i++) {
if (i == static_cast<size_t>(axis)) continue;
if (idx_dims[i] > data_dims[i])
THROW_IE_EXCEPTION << layer->name << " Incorrect number of data and indexes dimensions!";
}
LayerConfig config;
DataConfig dataConfig, indexesConfig, updatesConfig;
Precision dataPrecision = layer->outData[0]->getTensorDesc().getPrecision();
dataConfig.desc = TensorDesc(dataPrecision, data_dims,
layer->insData[SCATTER_DATA].lock()->getTensorDesc().getLayout());
dataConfig.constant = false;
dataConfig.inPlace = 0;
config.inConfs.push_back(dataConfig);
indexesConfig.desc = TensorDesc(inIdxPrecision, idx_dims,
layer->insData[SCATTER_INDEXES].lock()->getTensorDesc().getLayout());
config.inConfs.push_back(indexesConfig);
updatesConfig.desc = TensorDesc(dataPrecision, upd_dims,
layer->insData[SCATTER_UPDATES].lock()->getTensorDesc().getLayout());
config.inConfs.push_back(updatesConfig);
DataConfig outConfig;
outConfig.desc = TensorDesc(dataPrecision, dst_dims, layer->outData[0]->getTensorDesc().getLayout());
outConfig.constant = false;
outConfig.inPlace = 0;
config.outConfs.push_back(outConfig);
config.dynBatchSupport = false;
confs.push_back(config);
} catch (InferenceEngine::details::InferenceEngineException &ex) {
errorMsg = ex.what();
}
}
StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept override {
switch (inputs[SCATTER_INDEXES]->getTensorDesc().getPrecision()) {
case Precision::FP32:
scatter<float>(inputs[SCATTER_DATA], inputs[SCATTER_INDEXES], inputs[SCATTER_UPDATES], outputs[0]);
break;
case Precision::I32:
scatter<int32_t>(inputs[SCATTER_DATA], inputs[SCATTER_INDEXES], inputs[SCATTER_UPDATES], outputs[0]);
break;
default:
return GENERAL_ERROR;
}
return OK;
}
private:
template <typename index_t>
void scatter(Blob::Ptr data, Blob::Ptr indexes, Blob::Ptr updates, Blob::Ptr output) {
const uint8_t *src_data = data->cbuffer().as<const uint8_t *>() + data->getTensorDesc().getBlockingDesc().getOffsetPadding();
const index_t *src_index = indexes->cbuffer().as<const index_t *>() + indexes->getTensorDesc().getBlockingDesc().getOffsetPadding();
const uint8_t *src_updates = updates->cbuffer().as<const uint8_t *>() + updates->getTensorDesc().getBlockingDesc().getOffsetPadding();
uint8_t *dst_data = output->cbuffer().as<uint8_t*>() + output->getTensorDesc().getBlockingDesc().getOffsetPadding();
size_t data_size = data->getTensorDesc().getPrecision().size();
InferenceEngine::SizeVector index_dims = indexes->getTensorDesc().getDims();
InferenceEngine::SizeVector data_dims = data->getTensorDesc().getDims();
InferenceEngine::SizeVector dataStrides = data->getTensorDesc().getBlockingDesc().getStrides();
if (src_data != dst_data) {
parallel_nt(0, [&](const int ithr, const int nthr) {
size_t start = 0, end = 0;
splitter(output->size(), nthr, ithr, start, end);
size_t size = (end - start) * data_size;
start *= data_size;
simple_copy(dst_data + start, size, src_data + start, size);
});
}
parallel_nt(0, [&](const int ithr, const int nthr) {
int j;
size_t i, dst_idx = 0, start = 0, end = 0;
SizeVector counters(index_dims.size(), 0);
splitter(indexes->size(), nthr, ithr, start, end);
for (j = index_dims.size() - 1, i = start; j >= 0; j--) {
counters[j] = i % index_dims[j];
i /= index_dims[j];
}
for (i = 0; i < static_cast<size_t>(axis); ++i)
dst_idx += counters[i] * dataStrides[i];
for (i++; i < data_dims.size(); ++i)
dst_idx += counters[i] * dataStrides[i];
for (size_t iwork = start; iwork < end; iwork++) {
unsigned int idx = static_cast<unsigned int>(src_index[iwork]);
if (idx < data_dims[axis])
simple_copy(dst_data + data_size * (dst_idx + idx * dataStrides[axis]), data_size,
src_updates + iwork * data_size, data_size);
for (j = index_dims.size() - 1; j >= 0; j--) {
counters[j]++;
if (counters[j] < index_dims[j]) {
if (j != static_cast<size_t>(axis))
dst_idx += dataStrides[j];
break;
} else {
counters[j] = 0;
for (dst_idx = 0, i = 0; i < static_cast<size_t>(axis); ++i)
dst_idx += counters[i] * dataStrides[i];
for (i++; i < data_dims.size(); ++i)
dst_idx += counters[i] * dataStrides[i];
}
}
}
});
}
int axis = 0;
const size_t SCATTER_DATA = 0;
const size_t SCATTER_INDEXES = 1;
const size_t SCATTER_UPDATES = 2;
};
REG_FACTORY_FOR(ScatterImpl, ScatterUpdate);
} // namespace Cpu
} // namespace Extensions
} // namespace InferenceEngine

View File

@@ -36,7 +36,6 @@ const auto ScatterNDUpdateCases = ::testing::Combine(
::testing::Values(CommonTestUtils::DEVICE_CPU)
);
// open after ops support in ngraph merged
// INSTANTIATE_TEST_CASE_P(ScatterNDUpdate, ScatterNDUpdateLayerTest, ScatterNDUpdateCases, ScatterNDUpdateLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(ScatterNDUpdate, ScatterNDUpdateLayerTest, ScatterNDUpdateCases, ScatterNDUpdateLayerTest::getTestCaseName);
} // namespace
} // namespace

View File

@@ -41,7 +41,6 @@ const auto ScatterUpdateCase = ::testing::Combine(
::testing::Values(CommonTestUtils::DEVICE_CPU)
);
// open after ngraph reference implementation merged
// INSTANTIATE_TEST_CASE_P(ScatterUpdate, ScatterUpdateLayerTest, ScatterUpdateCase, ScatterUpdateLayerTest::getTestCaseName);
INSTANTIATE_TEST_CASE_P(ScatterUpdate, ScatterUpdateLayerTest, ScatterUpdateCase, ScatterUpdateLayerTest::getTestCaseName);
} // namespace
} // namespace

View File

@@ -10,6 +10,7 @@
#include <ngraph/opsets/opset1.hpp>
#include <ngraph/opsets/opset2.hpp>
#include <ngraph/opsets/opset3.hpp>
#include <ngraph/opsets/opset4.hpp>
#include "ngraph_functions/utils/data_utils.hpp"

View File

@@ -13,10 +13,8 @@ std::shared_ptr<ngraph::Node> makeScatterNDUpdate(const ngraph::Output<Node> &in,
const std::vector<size_t>& indices,
const ngraph::Output<Node> &update) {
auto indicesNode = std::make_shared<ngraph::opset1::Constant>(indicesType, indicesShape, indices);
// blocked by ngraph merge
// auto dtsNode = std::make_shared<ngraph::opset3::ScatterNDUpdate>(in, indicesNode, update);
// return dtsNode;
return nullptr;
auto scatterNode = std::make_shared<ngraph::opset4::ScatterNDUpdate>(in, indicesNode, update);
return scatterNode;
}
} // namespace builder
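A hypothetical call site for the builder above (sketch only: the indicesType and indicesShape parameters are assumed from the Constant construction in the body, since the diff context truncates the signature, and all shapes are illustrative):

// Replace row 0 of a 2x3 parameter with a 1x3 update tensor.
auto data = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32,
                                                        ngraph::Shape{2, 3});
auto update = std::make_shared<ngraph::opset1::Parameter>(ngraph::element::f32,
                                                          ngraph::Shape{1, 3});
// One rank-1 index tuple (indices shape {1, 1}) selecting row 0
auto scatter = ngraph::builder::makeScatterNDUpdate(
    data, ngraph::element::i64, std::vector<size_t>{1, 1},
    std::vector<size_t>{0}, update);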

View File

@@ -93,6 +93,8 @@
#include "op/group_conv.hpp"
#include "reference/detection_output.hpp"
#include "reference/scatter_nd_update.hpp"
#include "reference/scatter_update.hpp"
namespace ngraph
{
@@ -1144,6 +1146,81 @@ protected:
break;
}
case OP_TYPEID::ScatterNDUpdate_v3:
{
const op::v3::ScatterNDUpdate* scatterNDUpd =
static_cast<const op::v3::ScatterNDUpdate*>(&node);
auto idxType = scatterNDUpd->get_input_element_type(1);
if (idxType == element::i32)
{
reference::scatterNdUpdate<T, int32_t>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const int32_t>(),
args[2]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2));
}
else if (idxType == element::i64)
{
reference::scatterNdUpdate<T, int64_t>(args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const int64_t>(),
args[2]->get_data_ptr<const T>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2));
}
else
{
throw ngraph_error(
"ScatterNDUpdate layer supports only i32 and i64 'indices' input precision!");
}
break;
}
case OP_TYPEID::ScatterUpdate_v3:
{
const op::v3::ScatterUpdate* scatterUpd =
static_cast<const op::v3::ScatterUpdate*>(&node);
if (scatterUpd->get_input_element_type(3) != element::i64)
throw ngraph_error(
"ScatterUpdate layer supports only i64 'axis' input precision!");
auto idxType = scatterUpd->get_input_element_type(1);
if (idxType == element::i32)
{
reference::scatterUpdate<T, int32_t, int64_t>(
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const int32_t>(),
args[2]->get_data_ptr<const T>(),
args[3]->get_data_ptr<const int64_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2));
}
else if (idxType == element::i64)
{
reference::scatterUpdate<T, int64_t, int64_t>(
args[0]->get_data_ptr<const T>(),
args[1]->get_data_ptr<const int64_t>(),
args[2]->get_data_ptr<const T>(),
args[3]->get_data_ptr<const int64_t>(),
out[0]->get_data_ptr<T>(),
node.get_input_shape(0),
node.get_input_shape(1),
node.get_input_shape(2));
}
else
{
throw ngraph_error(
"ScatterUpdate layer supports only i32 and i64 'indices' input precision!");
}
break;
}
// Fused Ops are not supported in interpreter. They need to be decomposed before execution
case OP_TYPEID::DepthToSpace:

View File

@@ -37,4 +37,6 @@ NGRAPH_OP(EmbeddingSegmentsSum, op::v3)
NGRAPH_OP(ExtractImagePatches, op::v3)
NGRAPH_OP(ShapeOf, op::v3)
NGRAPH_OP(NonZero, op::v3)
NGRAPH_OP(ScatterNDUpdate, op::v3)
NGRAPH_OP(ScatterUpdate, op::v3)
#undef ID_SUFFIX

View File

@@ -0,0 +1,63 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape.hpp"
using namespace ngraph;
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename dataType, typename indicesType>
void scatterNdUpdate(const dataType* inputData,
const indicesType* indices,
const dataType* updates,
dataType* outBuf,
const Shape& dataShape,
const Shape& indicesShape,
const Shape& updatesShape)
{
// Number of slices to write: product of all but the innermost 'indices' dims
size_t numSlices = 1;
// Elements per slice: the trailing 'updates' dims beyond the index tuples
size_t sliceSize = 1;
for (size_t i = 0; i < indicesShape.size() - 1; i++)
{
numSlices *= indicesShape[i];
}
for (size_t i = indicesShape.size() - 1; i < updatesShape.size(); i++)
{
sliceSize *= updatesShape[i];
}
// Each index tuple supplies k coordinates into 'data'
const size_t k = indicesShape.back();
// Start from a copy of the input; addressed slices are overwritten below
std::memcpy(outBuf, inputData, sizeof(dataType) * shape_size(dataShape));
CoordinateTransform dataTransform{dataShape};
for (size_t i = 0; i < numSlices; i++)
{
// Slice start coordinate: k indexed dims, zeros for the remaining dims
Coordinate coord;
for (size_t j = 0; j < k; j++)
{
coord.push_back(indices[i * k + j]);
}
for (size_t j = k; j < dataShape.size(); j++)
{
coord.push_back(0);
}
const size_t startDataIdx = dataTransform.index(coord);
for (size_t j = 0; j < sliceSize; j++)
{
outBuf[startDataIdx + j] = updates[i * sliceSize + j];
}
}
}
} // namespace reference
} // namespace runtime
} // namespace ngraph
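A minimal standalone sketch exercising this kernel (a hedged example, not part of the commit; the include path mirrors the one used by int_executable above, and values are illustrative):

#include <cstdio>
#include "reference/scatter_nd_update.hpp"

int main()
{
    // data: 2x3 matrix; one rank-1 index tuple {1} addresses row 1
    float data[6] = {1, 2, 3, 4, 5, 6};
    int32_t indices[1] = {1};         // indicesShape {1, 1}, so k = 1
    float updates[3] = {10, 20, 30};  // updatesShape {1, 3}: one 3-element slice
    float out[6];
    ngraph::runtime::reference::scatterNdUpdate<float, int32_t>(
        data, indices, updates, out,
        ngraph::Shape{2, 3}, ngraph::Shape{1, 1}, ngraph::Shape{1, 3});
    for (float v : out)
        std::printf("%g ", v);        // prints: 1 2 3 10 20 30
    return 0;
}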

View File

@@ -0,0 +1,86 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <cstring>
#include <string>

#include "ngraph/coordinate_transform.hpp"
#include "ngraph/shape.hpp"
namespace ngraph
{
namespace runtime
{
namespace reference
{
template <typename dataType, typename indicesType, typename axisType>
void scatterUpdate(const dataType* inputData,
const indicesType* indices,
const dataType* updates,
const axisType* _axis,
dataType* outBuf,
const Shape& dataShape,
const Shape& indicesShape,
const Shape& updatesShape)
{
int rank = static_cast<int>(dataShape.size());
if (_axis[0] < -rank || _axis[0] > rank - 1)
{
std::string error =
std::string("ScatterUpdate layer has out of bounds axis value: ") +
std::to_string(_axis[0]);
throw ngraph_error(error);
}
// Normalize a negative axis into [0, rank)
size_t axis = _axis[0] < 0 ? _axis[0] + rank : _axis[0];
CoordinateTransform indicesTransform{indicesShape};
// Iterate 'data' coordinates with the scatter axis removed; the indexed
// position along 'axis' is re-inserted for each element below
Shape dataShapeIter = dataShape;
dataShapeIter.erase(dataShapeIter.begin() + axis);
CoordinateTransform dataTransfIter{dataShapeIter};
CoordinateTransform updateTransform{updatesShape};
CoordinateTransform dataTransform{dataShape};
// Start from a copy of the input; indexed slices are overwritten below
std::memcpy(outBuf, inputData, sizeof(dataType) * shape_size(dataShape));
for (const Coordinate& indicesCoordIt : indicesTransform)
{
const size_t indicesIdx = indicesTransform.index(indicesCoordIt);
if (indices[indicesIdx] < 0)
{
std::string error =
std::string("ScatterUpdate layer has negative index value: ") +
std::to_string(indices[indicesIdx]);
throw ngraph_error(error);
}
const size_t idx = static_cast<size_t>(indices[indicesIdx]);
if (dataShape[axis] <= idx)
{
std::string error =
std::string("ScatterUpdate layer has out of bounds coordinate: ") +
std::to_string(idx) + " on 'data' input on " + std::to_string(axis) +
"th axis";
throw ngraph_error(error);
}
for (const Coordinate& dataCoordIt : dataTransfIter)
{
Coordinate dataCoord = dataCoordIt;
dataCoord.insert(dataCoord.begin() + axis, idx);
const size_t startIndices = dataTransform.index(dataCoord);
auto updCoord = dataCoordIt;
updCoord.insert(
updCoord.begin() + axis, indicesCoordIt.begin(), indicesCoordIt.end());
const size_t startUpd = updateTransform.index(updCoord);
outBuf[startIndices] = updates[startUpd];
}
}
}
} // namespace reference
} // namespace runtime
} // namespace ngraph
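An analogous hedged sketch for the axis-based variant (values and include path are illustrative):

#include <cstdio>
#include "reference/scatter_update.hpp"

int main()
{
    // data: 3x2 zeros; write two 1x2 update rows at rows 0 and 2 along axis 0
    float data[6] = {0, 0, 0, 0, 0, 0};
    int32_t indices[2] = {0, 2};          // indicesShape {2}
    float updates[4] = {10, 20, 30, 40};  // updatesShape {2, 2}
    int64_t axis[1] = {0};
    float out[6];
    ngraph::runtime::reference::scatterUpdate<float, int32_t, int64_t>(
        data, indices, updates, axis, out,
        ngraph::Shape{3, 2}, ngraph::Shape{2}, ngraph::Shape{2, 2});
    for (float v : out)
        std::printf("%g ", v);            // prints: 10 20 0 0 30 40
    return 0;
}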