[NGraph] Add scatterNDUpdate and scatterUpdate reference implementations (#1494)
This commit is contained in:
parent
caa38130b9
commit
4054364fbf
@ -64,7 +64,6 @@ set(LAYERS
|
|||||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/gather_tree.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/nodes/gather_tree.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/grn.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/nodes/grn.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/non_max_suppression.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/nodes/non_max_suppression.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/scatter.cpp
|
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/log_softmax.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/nodes/log_softmax.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/math.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/nodes/math.cpp
|
||||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/one_hot.cpp
|
${CMAKE_CURRENT_SOURCE_DIR}/nodes/one_hot.cpp
|
||||||
|
@ -51,7 +51,6 @@ MKLDNN_EXTENSION_NODE(FillImpl, Fill);
|
|||||||
MKLDNN_EXTENSION_NODE(UniqueImpl, Unique);
|
MKLDNN_EXTENSION_NODE(UniqueImpl, Unique);
|
||||||
MKLDNN_EXTENSION_NODE(PSROIPoolingImpl, PSROIPooling);
|
MKLDNN_EXTENSION_NODE(PSROIPoolingImpl, PSROIPooling);
|
||||||
MKLDNN_EXTENSION_NODE(DepthToSpaceImpl, DepthToSpace);
|
MKLDNN_EXTENSION_NODE(DepthToSpaceImpl, DepthToSpace);
|
||||||
MKLDNN_EXTENSION_NODE(ScatterImpl, ScatterUpdate);
|
|
||||||
MKLDNN_EXTENSION_NODE(OneHotImpl, OneHot);
|
MKLDNN_EXTENSION_NODE(OneHotImpl, OneHot);
|
||||||
MKLDNN_EXTENSION_NODE(BroadcastImpl, Broadcast);
|
MKLDNN_EXTENSION_NODE(BroadcastImpl, Broadcast);
|
||||||
MKLDNN_EXTENSION_NODE(ExperimentalSparseWeightedReduceImpl, ExperimentalSparseWeightedSum);
|
MKLDNN_EXTENSION_NODE(ExperimentalSparseWeightedReduceImpl, ExperimentalSparseWeightedSum);
|
||||||
|
@ -1,188 +0,0 @@
|
|||||||
// Copyright (C) 2018-2020 Intel Corporation
|
|
||||||
// SPDX-License-Identifier: Apache-2.0
|
|
||||||
//
|
|
||||||
|
|
||||||
#include "base.hpp"
|
|
||||||
|
|
||||||
#include <cmath>
|
|
||||||
#include <string>
|
|
||||||
#include <vector>
|
|
||||||
#include <cassert>
|
|
||||||
#include <algorithm>
|
|
||||||
#include <limits>
|
|
||||||
#include "ie_parallel.hpp"
|
|
||||||
#include "common/simple_copy.h"
|
|
||||||
|
|
||||||
namespace InferenceEngine {
|
|
||||||
namespace Extensions {
|
|
||||||
namespace Cpu {
|
|
||||||
|
|
||||||
class ScatterImpl: public ExtLayerBase {
|
|
||||||
public:
|
|
||||||
explicit ScatterImpl(const CNNLayer* layer) {
|
|
||||||
try {
|
|
||||||
if (layer->insData.size() != 3 || layer->outData.size() != 1)
|
|
||||||
THROW_IE_EXCEPTION << layer->name << " Incorrect number of input/output tensors!";
|
|
||||||
|
|
||||||
|
|
||||||
Precision inIdxPrecision = layer->insData[SCATTER_INDEXES].lock()->getTensorDesc().getPrecision();
|
|
||||||
if (inIdxPrecision != Precision::FP32 && inIdxPrecision != Precision::I32)
|
|
||||||
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Indexes' precision. Only FP32 or I32 are supported!";
|
|
||||||
|
|
||||||
Precision inDataPrecision = layer->insData[SCATTER_DATA].lock()->getTensorDesc().getPrecision();
|
|
||||||
if (inDataPrecision != layer->insData[SCATTER_UPDATES].lock()->getTensorDesc().getPrecision())
|
|
||||||
THROW_IE_EXCEPTION << layer->name << " Precision should be equal for input tensors 'Data' and 'Updates'";
|
|
||||||
|
|
||||||
// Remove redundant dimensions
|
|
||||||
const SizeVector& data_dims = layer->insData[SCATTER_DATA].lock()->getTensorDesc().getDims();
|
|
||||||
if (data_dims.size() == 0 ||
|
|
||||||
(data_dims.size() == 1 && data_dims[0] == 1) ||
|
|
||||||
layer->insData[SCATTER_DATA].lock()->getTensorDesc().getLayout() == Layout::SCALAR)
|
|
||||||
THROW_IE_EXCEPTION << layer->name << " 'Data' tensor rank should be >= 1";
|
|
||||||
|
|
||||||
axis = layer->GetParamAsInt("axis", 0);
|
|
||||||
|
|
||||||
IE_ASSERT(-static_cast<int>(data_dims.size()) <= axis && axis < static_cast<int>(data_dims.size()))
|
|
||||||
<< layer->name << " Incorrect input parameters dimensions and axis number!";
|
|
||||||
|
|
||||||
if (axis < 0)
|
|
||||||
axis += data_dims.size();
|
|
||||||
|
|
||||||
SizeVector dst_dims = layer->outData[0]->getTensorDesc().getDims();
|
|
||||||
if (data_dims != dst_dims)
|
|
||||||
THROW_IE_EXCEPTION << layer->name << " Incorrect number of input/output dimensions!";
|
|
||||||
|
|
||||||
SizeVector idx_dims = layer->insData[SCATTER_INDEXES].lock()->getTensorDesc().getDims();
|
|
||||||
if (idx_dims.size() == 0 ||
|
|
||||||
(idx_dims.size() == 1 && idx_dims[0] == 1) ||
|
|
||||||
layer->insData[SCATTER_INDEXES].lock()->getTensorDesc().getLayout() == Layout::SCALAR)
|
|
||||||
THROW_IE_EXCEPTION << layer->name << " 'Indexes' tensor rank should be >= 1";
|
|
||||||
|
|
||||||
SizeVector upd_dims = layer->insData[SCATTER_UPDATES].lock()->getTensorDesc().getDims();
|
|
||||||
if (layer->insData[SCATTER_UPDATES].lock()->getTensorDesc().getLayout() == Layout::SCALAR)
|
|
||||||
THROW_IE_EXCEPTION << layer->name << " 'Indexes' tensor rank should be >= 1";
|
|
||||||
|
|
||||||
if (idx_dims != upd_dims)
|
|
||||||
THROW_IE_EXCEPTION << layer->name << " Incorrect number of 'indexes' and 'updates' tensors dimension";
|
|
||||||
|
|
||||||
for (size_t i = 0; i < idx_dims.size(); i++) {
|
|
||||||
if (i == static_cast<size_t>(axis)) continue;
|
|
||||||
if (idx_dims[i] > data_dims[i])
|
|
||||||
THROW_IE_EXCEPTION << layer->name << " Incorrect number of data and indexes dimensions!";
|
|
||||||
}
|
|
||||||
|
|
||||||
LayerConfig config;
|
|
||||||
DataConfig dataConfig, indexesConfig, updatesConfig;
|
|
||||||
Precision dataPrecision = layer->outData[0]->getTensorDesc().getPrecision();
|
|
||||||
dataConfig.desc = TensorDesc(dataPrecision, data_dims,
|
|
||||||
layer->insData[SCATTER_DATA].lock()->getTensorDesc().getLayout());
|
|
||||||
dataConfig.constant = false;
|
|
||||||
dataConfig.inPlace = 0;
|
|
||||||
config.inConfs.push_back(dataConfig);
|
|
||||||
indexesConfig.desc = TensorDesc(inIdxPrecision, idx_dims,
|
|
||||||
layer->insData[SCATTER_INDEXES].lock()->getTensorDesc().getLayout());
|
|
||||||
config.inConfs.push_back(indexesConfig);
|
|
||||||
updatesConfig.desc = TensorDesc(dataPrecision, upd_dims,
|
|
||||||
layer->insData[SCATTER_UPDATES].lock()->getTensorDesc().getLayout());
|
|
||||||
config.inConfs.push_back(updatesConfig);
|
|
||||||
|
|
||||||
DataConfig outConfig;
|
|
||||||
outConfig.desc = TensorDesc(dataPrecision, dst_dims, layer->outData[0]->getTensorDesc().getLayout());
|
|
||||||
outConfig.constant = false;
|
|
||||||
outConfig.inPlace = 0;
|
|
||||||
config.outConfs.push_back(outConfig);
|
|
||||||
config.dynBatchSupport = false;
|
|
||||||
confs.push_back(config);
|
|
||||||
} catch (InferenceEngine::details::InferenceEngineException &ex) {
|
|
||||||
errorMsg = ex.what();
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept override {
|
|
||||||
switch (inputs[SCATTER_INDEXES]->getTensorDesc().getPrecision()) {
|
|
||||||
case Precision::FP32:
|
|
||||||
scatter<float>(inputs[SCATTER_DATA], inputs[SCATTER_INDEXES], inputs[SCATTER_UPDATES], outputs[0]);
|
|
||||||
break;
|
|
||||||
case Precision::I32:
|
|
||||||
scatter<int32_t>(inputs[SCATTER_DATA], inputs[SCATTER_INDEXES], inputs[SCATTER_UPDATES], outputs[0]);
|
|
||||||
break;
|
|
||||||
default:
|
|
||||||
return GENERAL_ERROR;
|
|
||||||
}
|
|
||||||
|
|
||||||
return OK;
|
|
||||||
}
|
|
||||||
|
|
||||||
private:
|
|
||||||
template <typename index_t>
|
|
||||||
void scatter(Blob::Ptr data, Blob::Ptr indexes, Blob::Ptr updates, Blob::Ptr output) {
|
|
||||||
const uint8_t *src_data = data->cbuffer().as<const uint8_t *>() + data->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
|
||||||
const index_t *src_index = indexes->cbuffer().as<const index_t *>() + indexes->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
|
||||||
const uint8_t *src_updates = updates->cbuffer().as<const uint8_t *>() + updates->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
|
||||||
uint8_t *dst_data = output->cbuffer().as<uint8_t*>() + output->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
|
||||||
size_t data_size = data->getTensorDesc().getPrecision().size();
|
|
||||||
|
|
||||||
InferenceEngine::SizeVector index_dims = indexes->getTensorDesc().getDims();
|
|
||||||
InferenceEngine::SizeVector data_dims = data->getTensorDesc().getDims();
|
|
||||||
InferenceEngine::SizeVector dataStrides = data->getTensorDesc().getBlockingDesc().getStrides();
|
|
||||||
|
|
||||||
if (src_data != dst_data) {
|
|
||||||
parallel_nt(0, [&](const int ithr, const int nthr) {
|
|
||||||
size_t start = 0, end = 0;
|
|
||||||
splitter(output->size(), nthr, ithr, start, end);
|
|
||||||
size_t size = (end - start) * data_size;
|
|
||||||
start *= data_size;
|
|
||||||
simple_copy(dst_data + start, size, src_data + start, size);
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
parallel_nt(0, [&](const int ithr, const int nthr) {
|
|
||||||
int j;
|
|
||||||
size_t i, dst_idx = 0, start = 0, end = 0;
|
|
||||||
SizeVector counters(index_dims.size(), 0);
|
|
||||||
splitter(indexes->size(), nthr, ithr, start, end);
|
|
||||||
for (j = index_dims.size() - 1, i = start; j >= 0; j--) {
|
|
||||||
counters[j] = i % index_dims[j];
|
|
||||||
i /= index_dims[j];
|
|
||||||
}
|
|
||||||
|
|
||||||
for (i = 0; i < static_cast<size_t>(axis); ++i)
|
|
||||||
dst_idx += counters[i] * dataStrides[i];
|
|
||||||
for (i++; i < data_dims.size(); ++i)
|
|
||||||
dst_idx += counters[i] * dataStrides[i];
|
|
||||||
|
|
||||||
for (size_t iwork = start; iwork < end; iwork++) {
|
|
||||||
unsigned int idx = static_cast<unsigned int>(src_index[iwork]);
|
|
||||||
if (idx < data_dims[axis])
|
|
||||||
simple_copy(dst_data + data_size * (dst_idx + idx * dataStrides[axis]), data_size,
|
|
||||||
src_updates + iwork * data_size, data_size);
|
|
||||||
|
|
||||||
for (j = index_dims.size() - 1; j >= 0; j--) {
|
|
||||||
counters[j]++;
|
|
||||||
if (counters[j] < index_dims[j]) {
|
|
||||||
if (j != static_cast<size_t>(axis))
|
|
||||||
dst_idx += dataStrides[j];
|
|
||||||
break;
|
|
||||||
} else {
|
|
||||||
counters[j] = 0;
|
|
||||||
for (dst_idx = 0, i = 0; i < static_cast<size_t>(axis); ++i)
|
|
||||||
dst_idx += counters[i] * dataStrides[i];
|
|
||||||
for (i++; i < data_dims.size(); ++i)
|
|
||||||
dst_idx += counters[i] * dataStrides[i];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
});
|
|
||||||
}
|
|
||||||
|
|
||||||
int axis = 0;
|
|
||||||
const size_t SCATTER_DATA = 0;
|
|
||||||
const size_t SCATTER_INDEXES = 1;
|
|
||||||
const size_t SCATTER_UPDATES = 2;
|
|
||||||
};
|
|
||||||
|
|
||||||
REG_FACTORY_FOR(ScatterImpl, ScatterUpdate);
|
|
||||||
|
|
||||||
} // namespace Cpu
|
|
||||||
} // namespace Extensions
|
|
||||||
} // namespace InferenceEngine
|
|
@ -36,7 +36,6 @@ const auto ScatterNDUpdateCases = ::testing::Combine(
|
|||||||
::testing::Values(CommonTestUtils::DEVICE_CPU)
|
::testing::Values(CommonTestUtils::DEVICE_CPU)
|
||||||
);
|
);
|
||||||
|
|
||||||
// open after ops support in ngraph merged
|
INSTANTIATE_TEST_CASE_P(ScatterNDUpdate, ScatterNDUpdateLayerTest, ScatterNDUpdateCases, ScatterNDUpdateLayerTest::getTestCaseName);
|
||||||
// INSTANTIATE_TEST_CASE_P(ScatterNDUpdate, ScatterNDUpdateLayerTest, ScatterNDUpdateCases, ScatterNDUpdateLayerTest::getTestCaseName);
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
@ -41,7 +41,6 @@ const auto ScatterUpdateCase = ::testing::Combine(
|
|||||||
::testing::Values(CommonTestUtils::DEVICE_CPU)
|
::testing::Values(CommonTestUtils::DEVICE_CPU)
|
||||||
);
|
);
|
||||||
|
|
||||||
// open after ngraph reference implementation merged
|
INSTANTIATE_TEST_CASE_P(ScatterUpdate, ScatterUpdateLayerTest, ScatterUpdateCase, ScatterUpdateLayerTest::getTestCaseName);
|
||||||
// INSTANTIATE_TEST_CASE_P(ScatterUpdate, ScatterUpdateLayerTest, ScatterUpdateCase, ScatterUpdateLayerTest::getTestCaseName);
|
|
||||||
|
|
||||||
} // namespace
|
} // namespace
|
@ -10,6 +10,7 @@
|
|||||||
#include <ngraph/opsets/opset1.hpp>
|
#include <ngraph/opsets/opset1.hpp>
|
||||||
#include <ngraph/opsets/opset2.hpp>
|
#include <ngraph/opsets/opset2.hpp>
|
||||||
#include <ngraph/opsets/opset3.hpp>
|
#include <ngraph/opsets/opset3.hpp>
|
||||||
|
#include <ngraph/opsets/opset4.hpp>
|
||||||
|
|
||||||
#include "ngraph_functions/utils/data_utils.hpp"
|
#include "ngraph_functions/utils/data_utils.hpp"
|
||||||
|
|
||||||
|
@ -13,10 +13,8 @@ std::shared_ptr<ngraph::Node> makeScatterNDUpdate(const ngraph::Output<Node> &in
|
|||||||
const std::vector<size_t>& indices,
|
const std::vector<size_t>& indices,
|
||||||
const ngraph::Output<Node> &update) {
|
const ngraph::Output<Node> &update) {
|
||||||
auto indicesNode = std::make_shared<ngraph::opset1::Constant>(indicesType, indicesShape, indices);
|
auto indicesNode = std::make_shared<ngraph::opset1::Constant>(indicesType, indicesShape, indices);
|
||||||
// blocked by ngraph merge
|
auto dtsNode = std::make_shared<ngraph::opset4::ScatterNDUpdate>(in, indicesNode, update);
|
||||||
// auto dtsNode = std::make_shared<ngraph::opset3::ScatterNDUpdate>(in, indicesNode, update);
|
return dtsNode;
|
||||||
// return dtsNode;
|
|
||||||
return nullptr;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
} // namespace builder
|
} // namespace builder
|
||||||
|
@ -93,6 +93,8 @@
|
|||||||
#include "op/group_conv.hpp"
|
#include "op/group_conv.hpp"
|
||||||
|
|
||||||
#include "reference/detection_output.hpp"
|
#include "reference/detection_output.hpp"
|
||||||
|
#include "reference/scatter_nd_update.hpp"
|
||||||
|
#include "reference/scatter_update.hpp"
|
||||||
|
|
||||||
namespace ngraph
|
namespace ngraph
|
||||||
{
|
{
|
||||||
@ -1144,6 +1146,81 @@ protected:
|
|||||||
|
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
case OP_TYPEID::ScatterNDUpdate_v3:
|
||||||
|
{
|
||||||
|
const op::ScatterNDUpdate* scatterNDUpd =
|
||||||
|
static_cast<const op::v3::ScatterNDUpdate*>(&node);
|
||||||
|
auto idxType = scatterNDUpd->get_input_element_type(1);
|
||||||
|
if (idxType == element::i32)
|
||||||
|
{
|
||||||
|
reference::scatterNdUpdate<T, int32_t>(args[0]->get_data_ptr<const T>(),
|
||||||
|
args[1]->get_data_ptr<const int32_t>(),
|
||||||
|
args[2]->get_data_ptr<const T>(),
|
||||||
|
out[0]->get_data_ptr<T>(),
|
||||||
|
node.get_input_shape(0),
|
||||||
|
node.get_input_shape(1),
|
||||||
|
node.get_input_shape(2));
|
||||||
|
}
|
||||||
|
else if (idxType == element::i64)
|
||||||
|
{
|
||||||
|
reference::scatterNdUpdate<T, int64_t>(args[0]->get_data_ptr<const T>(),
|
||||||
|
args[1]->get_data_ptr<const int64_t>(),
|
||||||
|
args[2]->get_data_ptr<const T>(),
|
||||||
|
out[0]->get_data_ptr<T>(),
|
||||||
|
node.get_input_shape(0),
|
||||||
|
node.get_input_shape(1),
|
||||||
|
node.get_input_shape(2));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
throw ngraph_error(
|
||||||
|
"ScatterNDUpdate layer support only i32 and i64 'indices' input precision!");
|
||||||
|
}
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
case OP_TYPEID::ScatterUpdate_v3:
|
||||||
|
{
|
||||||
|
const op::v3::ScatterUpdate* scatterUpd =
|
||||||
|
static_cast<const op::v3::ScatterUpdate*>(&node);
|
||||||
|
|
||||||
|
if (scatterUpd->get_input_element_type(3) != element::i64)
|
||||||
|
throw ngraph_error(
|
||||||
|
"ScatterNDUpdate layer support only i64 'axis' input precision!");
|
||||||
|
|
||||||
|
auto idxType = scatterUpd->get_input_element_type(1);
|
||||||
|
if (idxType == element::i32)
|
||||||
|
{
|
||||||
|
reference::scatterUpdate<T, int32_t, int64_t>(
|
||||||
|
args[0]->get_data_ptr<const T>(),
|
||||||
|
args[1]->get_data_ptr<const int32_t>(),
|
||||||
|
args[2]->get_data_ptr<const T>(),
|
||||||
|
args[3]->get_data_ptr<const int64_t>(),
|
||||||
|
out[0]->get_data_ptr<T>(),
|
||||||
|
node.get_input_shape(0),
|
||||||
|
node.get_input_shape(1),
|
||||||
|
node.get_input_shape(2));
|
||||||
|
}
|
||||||
|
else if (idxType == element::i64)
|
||||||
|
{
|
||||||
|
reference::scatterUpdate<T, int64_t, int64_t>(
|
||||||
|
args[0]->get_data_ptr<const T>(),
|
||||||
|
args[1]->get_data_ptr<const int64_t>(),
|
||||||
|
args[2]->get_data_ptr<const T>(),
|
||||||
|
args[3]->get_data_ptr<const int64_t>(),
|
||||||
|
out[0]->get_data_ptr<T>(),
|
||||||
|
node.get_input_shape(0),
|
||||||
|
node.get_input_shape(1),
|
||||||
|
node.get_input_shape(2));
|
||||||
|
}
|
||||||
|
else
|
||||||
|
{
|
||||||
|
throw ngraph_error(
|
||||||
|
"ScatterUpdate layer support only i32 and i64 'indices' input precision!");
|
||||||
|
}
|
||||||
|
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
|
||||||
// Fused Ops are not supported in interpreter. They need to be decomposed before execution
|
// Fused Ops are not supported in interpreter. They need to be decomposed before execution
|
||||||
case OP_TYPEID::DepthToSpace:
|
case OP_TYPEID::DepthToSpace:
|
||||||
|
@ -37,4 +37,6 @@ NGRAPH_OP(EmbeddingSegmentsSum, op::v3)
|
|||||||
NGRAPH_OP(ExtractImagePatches, op::v3)
|
NGRAPH_OP(ExtractImagePatches, op::v3)
|
||||||
NGRAPH_OP(ShapeOf, op::v3)
|
NGRAPH_OP(ShapeOf, op::v3)
|
||||||
NGRAPH_OP(NonZero, op::v3)
|
NGRAPH_OP(NonZero, op::v3)
|
||||||
|
NGRAPH_OP(ScatterNDUpdate, op::v3)
|
||||||
|
NGRAPH_OP(ScatterUpdate, op::v3)
|
||||||
#undef ID_SUFFIX
|
#undef ID_SUFFIX
|
||||||
|
@ -0,0 +1,63 @@
|
|||||||
|
// Copyright (C) 2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include "ngraph/coordinate_transform.hpp"
|
||||||
|
#include "ngraph/shape.hpp"
|
||||||
|
|
||||||
|
using namespace ngraph;
|
||||||
|
|
||||||
|
namespace ngraph
|
||||||
|
{
|
||||||
|
namespace runtime
|
||||||
|
{
|
||||||
|
namespace reference
|
||||||
|
{
|
||||||
|
template <typename dataType, typename indicesType>
|
||||||
|
void scatterNdUpdate(const dataType* inputData,
|
||||||
|
const indicesType* indices,
|
||||||
|
const dataType* updates,
|
||||||
|
dataType* outBuf,
|
||||||
|
const Shape& dataShape,
|
||||||
|
const Shape& indicesShape,
|
||||||
|
const Shape& updatesShape)
|
||||||
|
{
|
||||||
|
size_t numSlices = 1;
|
||||||
|
size_t sliceSize = 1;
|
||||||
|
for (size_t i = 0; i < indicesShape.size() - 1; i++)
|
||||||
|
{
|
||||||
|
numSlices *= indicesShape[i];
|
||||||
|
}
|
||||||
|
for (size_t i = indicesShape.size() - 1; i < updatesShape.size(); i++)
|
||||||
|
{
|
||||||
|
sliceSize *= updatesShape[i];
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t k = indicesShape.back();
|
||||||
|
std::memcpy(outBuf, inputData, sizeof(dataType) * shape_size(dataShape));
|
||||||
|
CoordinateTransform dataTransform{dataShape};
|
||||||
|
|
||||||
|
for (size_t i = 0; i < numSlices; i++)
|
||||||
|
{
|
||||||
|
Coordinate coord;
|
||||||
|
for (size_t j = 0; j < k; j++)
|
||||||
|
{
|
||||||
|
coord.push_back(indices[i * k + j]);
|
||||||
|
}
|
||||||
|
for (size_t j = k; j < dataShape.size(); j++)
|
||||||
|
{
|
||||||
|
coord.push_back(0);
|
||||||
|
}
|
||||||
|
|
||||||
|
const size_t startDataIdx = dataTransform.index(coord);
|
||||||
|
for (size_t j = 0; j < sliceSize; j++)
|
||||||
|
{
|
||||||
|
outBuf[startDataIdx + j] = updates[i * sliceSize + j];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace reference
|
||||||
|
} // namespace runtime
|
||||||
|
} // namespace ngraph
|
86
ngraph/test/runtime/interpreter/reference/scatter_update.hpp
Normal file
86
ngraph/test/runtime/interpreter/reference/scatter_update.hpp
Normal file
@ -0,0 +1,86 @@
|
|||||||
|
// Copyright (C) 2020 Intel Corporation
|
||||||
|
// SPDX-License-Identifier: Apache-2.0
|
||||||
|
//
|
||||||
|
|
||||||
|
#pragma once
|
||||||
|
|
||||||
|
#include <string>
|
||||||
|
#include "ngraph/coordinate_transform.hpp"
|
||||||
|
#include "ngraph/shape.hpp"
|
||||||
|
|
||||||
|
using namespace ngraph;
|
||||||
|
|
||||||
|
namespace ngraph
|
||||||
|
{
|
||||||
|
namespace runtime
|
||||||
|
{
|
||||||
|
namespace reference
|
||||||
|
{
|
||||||
|
template <typename dataType, typename indicesType, typename axisType>
|
||||||
|
void scatterUpdate(const dataType* inputData,
|
||||||
|
const indicesType* indices,
|
||||||
|
const dataType* updates,
|
||||||
|
const axisType* _axis,
|
||||||
|
dataType* outBuf,
|
||||||
|
const Shape& dataShape,
|
||||||
|
const Shape& indicesShape,
|
||||||
|
const Shape& updatesShape)
|
||||||
|
{
|
||||||
|
int rank = static_cast<int>(dataShape.size());
|
||||||
|
if (_axis[0] < -rank || _axis[0] > rank - 1)
|
||||||
|
{
|
||||||
|
std::string error =
|
||||||
|
std::string("ScatterUpdate layer has out of bounds axis value: ") +
|
||||||
|
std::to_string(_axis[0]);
|
||||||
|
throw ngraph_error(error);
|
||||||
|
}
|
||||||
|
size_t axis = _axis[0] < 0 ? _axis[0] + rank : _axis[0];
|
||||||
|
CoordinateTransform indicesTransform{indicesShape};
|
||||||
|
|
||||||
|
Shape dataShapeIter = dataShape;
|
||||||
|
dataShapeIter.erase(dataShapeIter.begin() + axis);
|
||||||
|
CoordinateTransform dataTransfIter{dataShapeIter};
|
||||||
|
|
||||||
|
CoordinateTransform updateTransform{updatesShape};
|
||||||
|
CoordinateTransform dataTransform{dataShape};
|
||||||
|
|
||||||
|
std::memcpy(outBuf, inputData, sizeof(dataType) * shape_size(dataShape));
|
||||||
|
|
||||||
|
for (const Coordinate& indicesCoordIt : indicesTransform)
|
||||||
|
{
|
||||||
|
const size_t indicesIdx = indicesTransform.index(indicesCoordIt);
|
||||||
|
|
||||||
|
if (indices[indicesIdx] < 0)
|
||||||
|
{
|
||||||
|
std::string error =
|
||||||
|
std::string("ScatterUpdate layer has negative index value: ") +
|
||||||
|
std::to_string(indices[indicesIdx]);
|
||||||
|
throw ngraph_error(error);
|
||||||
|
}
|
||||||
|
const size_t idx = static_cast<size_t>(indices[indicesIdx]);
|
||||||
|
if (dataShape[axis] <= idx)
|
||||||
|
{
|
||||||
|
std::string error =
|
||||||
|
std::string("ScatterUpdate layer has out of bounds coordinate: ") +
|
||||||
|
std::to_string(idx) + " on 'data' input on " + std::to_string(axis) +
|
||||||
|
"th axis";
|
||||||
|
throw ngraph_error(error);
|
||||||
|
}
|
||||||
|
|
||||||
|
for (const Coordinate& dataCoordIt : dataTransfIter)
|
||||||
|
{
|
||||||
|
Coordinate dataCoord = dataCoordIt;
|
||||||
|
dataCoord.insert(dataCoord.begin() + axis, idx);
|
||||||
|
const size_t startIndices = dataTransform.index(dataCoord);
|
||||||
|
|
||||||
|
auto updCoord = dataCoordIt;
|
||||||
|
updCoord.insert(
|
||||||
|
updCoord.begin() + axis, indicesCoordIt.begin(), indicesCoordIt.end());
|
||||||
|
const size_t startUpd = updateTransform.index(updCoord);
|
||||||
|
outBuf[startIndices] = updates[startUpd];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} // namespace reference
|
||||||
|
} // namespace runtime
|
||||||
|
} // namespace ngraph
|
Loading…
Reference in New Issue
Block a user