[NGraph] Add scatterNDUpdate and scatterUpdate reference implementations (#1494)
This commit is contained in:
parent
caa38130b9
commit
4054364fbf
@ -64,7 +64,6 @@ set(LAYERS
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/gather_tree.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/grn.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/non_max_suppression.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/scatter.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/log_softmax.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/math.cpp
|
||||
${CMAKE_CURRENT_SOURCE_DIR}/nodes/one_hot.cpp
|
||||
|
@ -51,7 +51,6 @@ MKLDNN_EXTENSION_NODE(FillImpl, Fill);
|
||||
MKLDNN_EXTENSION_NODE(UniqueImpl, Unique);
|
||||
MKLDNN_EXTENSION_NODE(PSROIPoolingImpl, PSROIPooling);
|
||||
MKLDNN_EXTENSION_NODE(DepthToSpaceImpl, DepthToSpace);
|
||||
MKLDNN_EXTENSION_NODE(ScatterImpl, ScatterUpdate);
|
||||
MKLDNN_EXTENSION_NODE(OneHotImpl, OneHot);
|
||||
MKLDNN_EXTENSION_NODE(BroadcastImpl, Broadcast);
|
||||
MKLDNN_EXTENSION_NODE(ExperimentalSparseWeightedReduceImpl, ExperimentalSparseWeightedSum);
|
||||
|
@ -1,188 +0,0 @@
|
||||
// Copyright (C) 2018-2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#include "base.hpp"
|
||||
|
||||
#include <cmath>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
#include <cassert>
|
||||
#include <algorithm>
|
||||
#include <limits>
|
||||
#include "ie_parallel.hpp"
|
||||
#include "common/simple_copy.h"
|
||||
|
||||
namespace InferenceEngine {
|
||||
namespace Extensions {
|
||||
namespace Cpu {
|
||||
|
||||
class ScatterImpl: public ExtLayerBase {
|
||||
public:
|
||||
explicit ScatterImpl(const CNNLayer* layer) {
|
||||
try {
|
||||
if (layer->insData.size() != 3 || layer->outData.size() != 1)
|
||||
THROW_IE_EXCEPTION << layer->name << " Incorrect number of input/output tensors!";
|
||||
|
||||
|
||||
Precision inIdxPrecision = layer->insData[SCATTER_INDEXES].lock()->getTensorDesc().getPrecision();
|
||||
if (inIdxPrecision != Precision::FP32 && inIdxPrecision != Precision::I32)
|
||||
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Indexes' precision. Only FP32 or I32 are supported!";
|
||||
|
||||
Precision inDataPrecision = layer->insData[SCATTER_DATA].lock()->getTensorDesc().getPrecision();
|
||||
if (inDataPrecision != layer->insData[SCATTER_UPDATES].lock()->getTensorDesc().getPrecision())
|
||||
THROW_IE_EXCEPTION << layer->name << " Precision should be equal for input tensors 'Data' and 'Updates'";
|
||||
|
||||
// Remove redundant dimensions
|
||||
const SizeVector& data_dims = layer->insData[SCATTER_DATA].lock()->getTensorDesc().getDims();
|
||||
if (data_dims.size() == 0 ||
|
||||
(data_dims.size() == 1 && data_dims[0] == 1) ||
|
||||
layer->insData[SCATTER_DATA].lock()->getTensorDesc().getLayout() == Layout::SCALAR)
|
||||
THROW_IE_EXCEPTION << layer->name << " 'Data' tensor rank should be >= 1";
|
||||
|
||||
axis = layer->GetParamAsInt("axis", 0);
|
||||
|
||||
IE_ASSERT(-static_cast<int>(data_dims.size()) <= axis && axis < static_cast<int>(data_dims.size()))
|
||||
<< layer->name << " Incorrect input parameters dimensions and axis number!";
|
||||
|
||||
if (axis < 0)
|
||||
axis += data_dims.size();
|
||||
|
||||
SizeVector dst_dims = layer->outData[0]->getTensorDesc().getDims();
|
||||
if (data_dims != dst_dims)
|
||||
THROW_IE_EXCEPTION << layer->name << " Incorrect number of input/output dimensions!";
|
||||
|
||||
SizeVector idx_dims = layer->insData[SCATTER_INDEXES].lock()->getTensorDesc().getDims();
|
||||
if (idx_dims.size() == 0 ||
|
||||
(idx_dims.size() == 1 && idx_dims[0] == 1) ||
|
||||
layer->insData[SCATTER_INDEXES].lock()->getTensorDesc().getLayout() == Layout::SCALAR)
|
||||
THROW_IE_EXCEPTION << layer->name << " 'Indexes' tensor rank should be >= 1";
|
||||
|
||||
SizeVector upd_dims = layer->insData[SCATTER_UPDATES].lock()->getTensorDesc().getDims();
|
||||
if (layer->insData[SCATTER_UPDATES].lock()->getTensorDesc().getLayout() == Layout::SCALAR)
|
||||
THROW_IE_EXCEPTION << layer->name << " 'Indexes' tensor rank should be >= 1";
|
||||
|
||||
if (idx_dims != upd_dims)
|
||||
THROW_IE_EXCEPTION << layer->name << " Incorrect number of 'indexes' and 'updates' tensors dimension";
|
||||
|
||||
for (size_t i = 0; i < idx_dims.size(); i++) {
|
||||
if (i == static_cast<size_t>(axis)) continue;
|
||||
if (idx_dims[i] > data_dims[i])
|
||||
THROW_IE_EXCEPTION << layer->name << " Incorrect number of data and indexes dimensions!";
|
||||
}
|
||||
|
||||
LayerConfig config;
|
||||
DataConfig dataConfig, indexesConfig, updatesConfig;
|
||||
Precision dataPrecision = layer->outData[0]->getTensorDesc().getPrecision();
|
||||
dataConfig.desc = TensorDesc(dataPrecision, data_dims,
|
||||
layer->insData[SCATTER_DATA].lock()->getTensorDesc().getLayout());
|
||||
dataConfig.constant = false;
|
||||
dataConfig.inPlace = 0;
|
||||
config.inConfs.push_back(dataConfig);
|
||||
indexesConfig.desc = TensorDesc(inIdxPrecision, idx_dims,
|
||||
layer->insData[SCATTER_INDEXES].lock()->getTensorDesc().getLayout());
|
||||
config.inConfs.push_back(indexesConfig);
|
||||
updatesConfig.desc = TensorDesc(dataPrecision, upd_dims,
|
||||
layer->insData[SCATTER_UPDATES].lock()->getTensorDesc().getLayout());
|
||||
config.inConfs.push_back(updatesConfig);
|
||||
|
||||
DataConfig outConfig;
|
||||
outConfig.desc = TensorDesc(dataPrecision, dst_dims, layer->outData[0]->getTensorDesc().getLayout());
|
||||
outConfig.constant = false;
|
||||
outConfig.inPlace = 0;
|
||||
config.outConfs.push_back(outConfig);
|
||||
config.dynBatchSupport = false;
|
||||
confs.push_back(config);
|
||||
} catch (InferenceEngine::details::InferenceEngineException &ex) {
|
||||
errorMsg = ex.what();
|
||||
}
|
||||
}
|
||||
|
||||
StatusCode execute(std::vector<Blob::Ptr>& inputs, std::vector<Blob::Ptr>& outputs, ResponseDesc *resp) noexcept override {
|
||||
switch (inputs[SCATTER_INDEXES]->getTensorDesc().getPrecision()) {
|
||||
case Precision::FP32:
|
||||
scatter<float>(inputs[SCATTER_DATA], inputs[SCATTER_INDEXES], inputs[SCATTER_UPDATES], outputs[0]);
|
||||
break;
|
||||
case Precision::I32:
|
||||
scatter<int32_t>(inputs[SCATTER_DATA], inputs[SCATTER_INDEXES], inputs[SCATTER_UPDATES], outputs[0]);
|
||||
break;
|
||||
default:
|
||||
return GENERAL_ERROR;
|
||||
}
|
||||
|
||||
return OK;
|
||||
}
|
||||
|
||||
private:
|
||||
template <typename index_t>
|
||||
void scatter(Blob::Ptr data, Blob::Ptr indexes, Blob::Ptr updates, Blob::Ptr output) {
|
||||
const uint8_t *src_data = data->cbuffer().as<const uint8_t *>() + data->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||
const index_t *src_index = indexes->cbuffer().as<const index_t *>() + indexes->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||
const uint8_t *src_updates = updates->cbuffer().as<const uint8_t *>() + updates->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||
uint8_t *dst_data = output->cbuffer().as<uint8_t*>() + output->getTensorDesc().getBlockingDesc().getOffsetPadding();
|
||||
size_t data_size = data->getTensorDesc().getPrecision().size();
|
||||
|
||||
InferenceEngine::SizeVector index_dims = indexes->getTensorDesc().getDims();
|
||||
InferenceEngine::SizeVector data_dims = data->getTensorDesc().getDims();
|
||||
InferenceEngine::SizeVector dataStrides = data->getTensorDesc().getBlockingDesc().getStrides();
|
||||
|
||||
if (src_data != dst_data) {
|
||||
parallel_nt(0, [&](const int ithr, const int nthr) {
|
||||
size_t start = 0, end = 0;
|
||||
splitter(output->size(), nthr, ithr, start, end);
|
||||
size_t size = (end - start) * data_size;
|
||||
start *= data_size;
|
||||
simple_copy(dst_data + start, size, src_data + start, size);
|
||||
});
|
||||
}
|
||||
|
||||
parallel_nt(0, [&](const int ithr, const int nthr) {
|
||||
int j;
|
||||
size_t i, dst_idx = 0, start = 0, end = 0;
|
||||
SizeVector counters(index_dims.size(), 0);
|
||||
splitter(indexes->size(), nthr, ithr, start, end);
|
||||
for (j = index_dims.size() - 1, i = start; j >= 0; j--) {
|
||||
counters[j] = i % index_dims[j];
|
||||
i /= index_dims[j];
|
||||
}
|
||||
|
||||
for (i = 0; i < static_cast<size_t>(axis); ++i)
|
||||
dst_idx += counters[i] * dataStrides[i];
|
||||
for (i++; i < data_dims.size(); ++i)
|
||||
dst_idx += counters[i] * dataStrides[i];
|
||||
|
||||
for (size_t iwork = start; iwork < end; iwork++) {
|
||||
unsigned int idx = static_cast<unsigned int>(src_index[iwork]);
|
||||
if (idx < data_dims[axis])
|
||||
simple_copy(dst_data + data_size * (dst_idx + idx * dataStrides[axis]), data_size,
|
||||
src_updates + iwork * data_size, data_size);
|
||||
|
||||
for (j = index_dims.size() - 1; j >= 0; j--) {
|
||||
counters[j]++;
|
||||
if (counters[j] < index_dims[j]) {
|
||||
if (j != static_cast<size_t>(axis))
|
||||
dst_idx += dataStrides[j];
|
||||
break;
|
||||
} else {
|
||||
counters[j] = 0;
|
||||
for (dst_idx = 0, i = 0; i < static_cast<size_t>(axis); ++i)
|
||||
dst_idx += counters[i] * dataStrides[i];
|
||||
for (i++; i < data_dims.size(); ++i)
|
||||
dst_idx += counters[i] * dataStrides[i];
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
int axis = 0;
|
||||
const size_t SCATTER_DATA = 0;
|
||||
const size_t SCATTER_INDEXES = 1;
|
||||
const size_t SCATTER_UPDATES = 2;
|
||||
};
|
||||
|
||||
REG_FACTORY_FOR(ScatterImpl, ScatterUpdate);
|
||||
|
||||
} // namespace Cpu
|
||||
} // namespace Extensions
|
||||
} // namespace InferenceEngine
|
@ -36,7 +36,6 @@ const auto ScatterNDUpdateCases = ::testing::Combine(
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)
|
||||
);
|
||||
|
||||
// open after ops support in ngraph merged
|
||||
// INSTANTIATE_TEST_CASE_P(ScatterNDUpdate, ScatterNDUpdateLayerTest, ScatterNDUpdateCases, ScatterNDUpdateLayerTest::getTestCaseName);
|
||||
INSTANTIATE_TEST_CASE_P(ScatterNDUpdate, ScatterNDUpdateLayerTest, ScatterNDUpdateCases, ScatterNDUpdateLayerTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
} // namespace
|
||||
|
@ -41,7 +41,6 @@ const auto ScatterUpdateCase = ::testing::Combine(
|
||||
::testing::Values(CommonTestUtils::DEVICE_CPU)
|
||||
);
|
||||
|
||||
// open after ngraph reference implementation merged
|
||||
// INSTANTIATE_TEST_CASE_P(ScatterUpdate, ScatterUpdateLayerTest, ScatterUpdateCase, ScatterUpdateLayerTest::getTestCaseName);
|
||||
INSTANTIATE_TEST_CASE_P(ScatterUpdate, ScatterUpdateLayerTest, ScatterUpdateCase, ScatterUpdateLayerTest::getTestCaseName);
|
||||
|
||||
} // namespace
|
||||
} // namespace
|
||||
|
@ -10,6 +10,7 @@
|
||||
#include <ngraph/opsets/opset1.hpp>
|
||||
#include <ngraph/opsets/opset2.hpp>
|
||||
#include <ngraph/opsets/opset3.hpp>
|
||||
#include <ngraph/opsets/opset4.hpp>
|
||||
|
||||
#include "ngraph_functions/utils/data_utils.hpp"
|
||||
|
||||
|
@ -13,10 +13,8 @@ std::shared_ptr<ngraph::Node> makeScatterNDUpdate(const ngraph::Output<Node> &in
|
||||
const std::vector<size_t>& indices,
|
||||
const ngraph::Output<Node> &update) {
|
||||
auto indicesNode = std::make_shared<ngraph::opset1::Constant>(indicesType, indicesShape, indices);
|
||||
// blocked by ngraph merge
|
||||
// auto dtsNode = std::make_shared<ngraph::opset3::ScatterNDUpdate>(in, indicesNode, update);
|
||||
// return dtsNode;
|
||||
return nullptr;
|
||||
auto dtsNode = std::make_shared<ngraph::opset4::ScatterNDUpdate>(in, indicesNode, update);
|
||||
return dtsNode;
|
||||
}
|
||||
|
||||
} // namespace builder
|
||||
|
@ -93,6 +93,8 @@
|
||||
#include "op/group_conv.hpp"
|
||||
|
||||
#include "reference/detection_output.hpp"
|
||||
#include "reference/scatter_nd_update.hpp"
|
||||
#include "reference/scatter_update.hpp"
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
@ -1144,6 +1146,81 @@ protected:
|
||||
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::ScatterNDUpdate_v3:
|
||||
{
|
||||
const op::ScatterNDUpdate* scatterNDUpd =
|
||||
static_cast<const op::v3::ScatterNDUpdate*>(&node);
|
||||
auto idxType = scatterNDUpd->get_input_element_type(1);
|
||||
if (idxType == element::i32)
|
||||
{
|
||||
reference::scatterNdUpdate<T, int32_t>(args[0]->get_data_ptr<const T>(),
|
||||
args[1]->get_data_ptr<const int32_t>(),
|
||||
args[2]->get_data_ptr<const T>(),
|
||||
out[0]->get_data_ptr<T>(),
|
||||
node.get_input_shape(0),
|
||||
node.get_input_shape(1),
|
||||
node.get_input_shape(2));
|
||||
}
|
||||
else if (idxType == element::i64)
|
||||
{
|
||||
reference::scatterNdUpdate<T, int64_t>(args[0]->get_data_ptr<const T>(),
|
||||
args[1]->get_data_ptr<const int64_t>(),
|
||||
args[2]->get_data_ptr<const T>(),
|
||||
out[0]->get_data_ptr<T>(),
|
||||
node.get_input_shape(0),
|
||||
node.get_input_shape(1),
|
||||
node.get_input_shape(2));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ngraph_error(
|
||||
"ScatterNDUpdate layer support only i32 and i64 'indices' input precision!");
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
case OP_TYPEID::ScatterUpdate_v3:
|
||||
{
|
||||
const op::v3::ScatterUpdate* scatterUpd =
|
||||
static_cast<const op::v3::ScatterUpdate*>(&node);
|
||||
|
||||
if (scatterUpd->get_input_element_type(3) != element::i64)
|
||||
throw ngraph_error(
|
||||
"ScatterNDUpdate layer support only i64 'axis' input precision!");
|
||||
|
||||
auto idxType = scatterUpd->get_input_element_type(1);
|
||||
if (idxType == element::i32)
|
||||
{
|
||||
reference::scatterUpdate<T, int32_t, int64_t>(
|
||||
args[0]->get_data_ptr<const T>(),
|
||||
args[1]->get_data_ptr<const int32_t>(),
|
||||
args[2]->get_data_ptr<const T>(),
|
||||
args[3]->get_data_ptr<const int64_t>(),
|
||||
out[0]->get_data_ptr<T>(),
|
||||
node.get_input_shape(0),
|
||||
node.get_input_shape(1),
|
||||
node.get_input_shape(2));
|
||||
}
|
||||
else if (idxType == element::i64)
|
||||
{
|
||||
reference::scatterUpdate<T, int64_t, int64_t>(
|
||||
args[0]->get_data_ptr<const T>(),
|
||||
args[1]->get_data_ptr<const int64_t>(),
|
||||
args[2]->get_data_ptr<const T>(),
|
||||
args[3]->get_data_ptr<const int64_t>(),
|
||||
out[0]->get_data_ptr<T>(),
|
||||
node.get_input_shape(0),
|
||||
node.get_input_shape(1),
|
||||
node.get_input_shape(2));
|
||||
}
|
||||
else
|
||||
{
|
||||
throw ngraph_error(
|
||||
"ScatterUpdate layer support only i32 and i64 'indices' input precision!");
|
||||
}
|
||||
|
||||
break;
|
||||
}
|
||||
|
||||
// Fused Ops are not supported in interpreter. They need to be decomposed before execution
|
||||
case OP_TYPEID::DepthToSpace:
|
||||
|
@ -37,4 +37,6 @@ NGRAPH_OP(EmbeddingSegmentsSum, op::v3)
|
||||
NGRAPH_OP(ExtractImagePatches, op::v3)
|
||||
NGRAPH_OP(ShapeOf, op::v3)
|
||||
NGRAPH_OP(NonZero, op::v3)
|
||||
NGRAPH_OP(ScatterNDUpdate, op::v3)
|
||||
NGRAPH_OP(ScatterUpdate, op::v3)
|
||||
#undef ID_SUFFIX
|
||||
|
@ -0,0 +1,63 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include "ngraph/coordinate_transform.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace runtime
|
||||
{
|
||||
namespace reference
|
||||
{
|
||||
template <typename dataType, typename indicesType>
|
||||
void scatterNdUpdate(const dataType* inputData,
|
||||
const indicesType* indices,
|
||||
const dataType* updates,
|
||||
dataType* outBuf,
|
||||
const Shape& dataShape,
|
||||
const Shape& indicesShape,
|
||||
const Shape& updatesShape)
|
||||
{
|
||||
size_t numSlices = 1;
|
||||
size_t sliceSize = 1;
|
||||
for (size_t i = 0; i < indicesShape.size() - 1; i++)
|
||||
{
|
||||
numSlices *= indicesShape[i];
|
||||
}
|
||||
for (size_t i = indicesShape.size() - 1; i < updatesShape.size(); i++)
|
||||
{
|
||||
sliceSize *= updatesShape[i];
|
||||
}
|
||||
|
||||
const size_t k = indicesShape.back();
|
||||
std::memcpy(outBuf, inputData, sizeof(dataType) * shape_size(dataShape));
|
||||
CoordinateTransform dataTransform{dataShape};
|
||||
|
||||
for (size_t i = 0; i < numSlices; i++)
|
||||
{
|
||||
Coordinate coord;
|
||||
for (size_t j = 0; j < k; j++)
|
||||
{
|
||||
coord.push_back(indices[i * k + j]);
|
||||
}
|
||||
for (size_t j = k; j < dataShape.size(); j++)
|
||||
{
|
||||
coord.push_back(0);
|
||||
}
|
||||
|
||||
const size_t startDataIdx = dataTransform.index(coord);
|
||||
for (size_t j = 0; j < sliceSize; j++)
|
||||
{
|
||||
outBuf[startDataIdx + j] = updates[i * sliceSize + j];
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace reference
|
||||
} // namespace runtime
|
||||
} // namespace ngraph
|
86
ngraph/test/runtime/interpreter/reference/scatter_update.hpp
Normal file
86
ngraph/test/runtime/interpreter/reference/scatter_update.hpp
Normal file
@ -0,0 +1,86 @@
|
||||
// Copyright (C) 2020 Intel Corporation
|
||||
// SPDX-License-Identifier: Apache-2.0
|
||||
//
|
||||
|
||||
#pragma once
|
||||
|
||||
#include <string>
|
||||
#include "ngraph/coordinate_transform.hpp"
|
||||
#include "ngraph/shape.hpp"
|
||||
|
||||
using namespace ngraph;
|
||||
|
||||
namespace ngraph
|
||||
{
|
||||
namespace runtime
|
||||
{
|
||||
namespace reference
|
||||
{
|
||||
template <typename dataType, typename indicesType, typename axisType>
|
||||
void scatterUpdate(const dataType* inputData,
|
||||
const indicesType* indices,
|
||||
const dataType* updates,
|
||||
const axisType* _axis,
|
||||
dataType* outBuf,
|
||||
const Shape& dataShape,
|
||||
const Shape& indicesShape,
|
||||
const Shape& updatesShape)
|
||||
{
|
||||
int rank = static_cast<int>(dataShape.size());
|
||||
if (_axis[0] < -rank || _axis[0] > rank - 1)
|
||||
{
|
||||
std::string error =
|
||||
std::string("ScatterUpdate layer has out of bounds axis value: ") +
|
||||
std::to_string(_axis[0]);
|
||||
throw ngraph_error(error);
|
||||
}
|
||||
size_t axis = _axis[0] < 0 ? _axis[0] + rank : _axis[0];
|
||||
CoordinateTransform indicesTransform{indicesShape};
|
||||
|
||||
Shape dataShapeIter = dataShape;
|
||||
dataShapeIter.erase(dataShapeIter.begin() + axis);
|
||||
CoordinateTransform dataTransfIter{dataShapeIter};
|
||||
|
||||
CoordinateTransform updateTransform{updatesShape};
|
||||
CoordinateTransform dataTransform{dataShape};
|
||||
|
||||
std::memcpy(outBuf, inputData, sizeof(dataType) * shape_size(dataShape));
|
||||
|
||||
for (const Coordinate& indicesCoordIt : indicesTransform)
|
||||
{
|
||||
const size_t indicesIdx = indicesTransform.index(indicesCoordIt);
|
||||
|
||||
if (indices[indicesIdx] < 0)
|
||||
{
|
||||
std::string error =
|
||||
std::string("ScatterUpdate layer has negative index value: ") +
|
||||
std::to_string(indices[indicesIdx]);
|
||||
throw ngraph_error(error);
|
||||
}
|
||||
const size_t idx = static_cast<size_t>(indices[indicesIdx]);
|
||||
if (dataShape[axis] <= idx)
|
||||
{
|
||||
std::string error =
|
||||
std::string("ScatterUpdate layer has out of bounds coordinate: ") +
|
||||
std::to_string(idx) + " on 'data' input on " + std::to_string(axis) +
|
||||
"th axis";
|
||||
throw ngraph_error(error);
|
||||
}
|
||||
|
||||
for (const Coordinate& dataCoordIt : dataTransfIter)
|
||||
{
|
||||
Coordinate dataCoord = dataCoordIt;
|
||||
dataCoord.insert(dataCoord.begin() + axis, idx);
|
||||
const size_t startIndices = dataTransform.index(dataCoord);
|
||||
|
||||
auto updCoord = dataCoordIt;
|
||||
updCoord.insert(
|
||||
updCoord.begin() + axis, indicesCoordIt.begin(), indicesCoordIt.end());
|
||||
const size_t startUpd = updateTransform.index(updCoord);
|
||||
outBuf[startIndices] = updates[startUpd];
|
||||
}
|
||||
}
|
||||
}
|
||||
} // namespace reference
|
||||
} // namespace runtime
|
||||
} // namespace ngraph
|
Loading…
Reference in New Issue
Block a user