[CPU] ScatterUpdate ScatterElementsUpdate and ScatterNDUpdate support (#909)

* scatter_update_series_enable

* scatter_update_series_enable

* add single layer tests
This commit is contained in:
Chenhu Wang 2020-07-10 16:19:23 +08:00 committed by GitHub
parent 8e368c5e81
commit b4e3dd5c7b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
19 changed files with 1190 additions and 2 deletions

View File

@ -9,7 +9,7 @@
**Detailed description**: *ScatterUpdate* creates a copy of the first input tensor with updated elements in positions specified with `indices` input
and values specified with `updates` tensor starting from the dimension with index `axis`. For the `data` tensor of shape `[d_0, d_1, ..., d_n]`,
`indices` tensor of shape `[i_0, i_1, ..., i_k]` and `updates` tensor of shape
`[d_0, d_1, ... d_(axis - 1), i_0, i_1, ..., i_k, d_(axis + k + 1), ..., d_n]` the operation computes
`[d_0, d_1, ... d_(axis - 1), i_0, i_1, ..., i_k, d_(axis + 1), ..., d_n]` the operation computes
for each `m, n, ..., p` of the `indices` tensor indices:
```

View File

@ -44,6 +44,7 @@ set(LAYERS
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_mvn_node.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_resample_node.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_normalize_node.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/mkldnn_scatter_update_node.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/list.cpp
${CMAKE_CURRENT_SOURCE_DIR}/nodes/batch_to_space.cpp

View File

@ -43,6 +43,7 @@
#include <nodes/mkldnn_resample_node.h>
#include <nodes/mkldnn_normalize_node.h>
#include <nodes/mkldnn_tensoriterator_node.h>
#include <nodes/mkldnn_scatter_update_node.h>
#include <mkldnn_types.h>
#include "mkldnn_extension_utils.h"
@ -115,6 +116,9 @@ static const InferenceEngine::details::caseless_unordered_map<std::string, Type>
{ "MVN", MVN},
{ "Resample", Resample},
{ "Normalize", Normalize},
{ "ScatterUpdate", ScatterUpdate},
{ "ScatterElementsUpdate", ScatterElementsUpdate},
{ "ScatterNDUpdate", ScatterNDUpdate},
};
Type TypeFromName(const std::string type) {

View File

@ -75,7 +75,10 @@ enum Type {
Convert,
MVN,
Resample,
Normalize
Normalize,
ScatterUpdate,
ScatterElementsUpdate,
ScatterNDUpdate
};
Type TypeFromName(const std::string type);
@ -158,6 +161,12 @@ static std::string NameFromType(Type type) {
return "Resample";
case Normalize:
return "Normalize";
case ScatterUpdate:
return "ScatterUpdate";
case ScatterElementsUpdate:
return "ScatterElementsUpdate";
case ScatterNDUpdate:
return "ScatterNDUpdate";
default:
return "Unknown";
}

View File

@ -0,0 +1,500 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "mkldnn_scatter_update_node.h"
#include "desc_iterator.hpp"
#include "mkldnn_quantize_node.h"
#include "mkldnn_depthwise_node.h"
#include "mkldnn_activation_node.h"
#include <ie_layers.h>
#include <mkldnn.hpp>
#include <string>
#include <vector>
#include <mkldnn_types.h>
#include <mkldnn_extension_utils.h>
#include <ie_layers_internal.hpp>
#include "ie_parallel.hpp"
#include <algorithm>
#include "jit_generator.hpp"
#include "jit_uni_eltwise.hpp"
#include "jit_uni_depthwise.hpp"
#include "jit_uni_quantization.hpp"
#include "common/simple_copy.h"
using namespace mkldnn;
using namespace MKLDNNPlugin;
using namespace InferenceEngine;
using namespace mkldnn::impl;
using namespace mkldnn::impl::cpu;
using namespace mkldnn::impl::utils;
using namespace Xbyak;
MKLDNNScatterUpdateNode::MKLDNNScatterUpdateNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache)
: MKLDNNNode(layer, eng, cache), dataSize(0lu), indicesSize(0lu), axisSize(0lu),
dataPrec(Precision::UNSPECIFIED), indicesPrec(Precision::UNSPECIFIED), axisPrec(Precision::UNSPECIFIED) {}
void MKLDNNScatterUpdateNode::getSupportedDescriptors() {
if (!descs.empty())
return;
if ((getParentEdges().size() != 3) && (getParentEdges().size() != 4))
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' has incorrect number of input edges";
if (getChildEdges().empty())
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' has incorrect number of output edges";
if (getParentEdgeAt(DATA_ID)->getDims().ndims() < 1 ||
getParentEdgeAt(INDICES_ID)->getDims().ndims() < 1 ||
getParentEdgeAt(UPDATE_ID)->getDims().ndims() < 1) {
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' do not support scalar input";
}
Type scatterUpdateType = getType();
if (scatterUpdateType == ScatterUpdate) {
scatterUpdateMode = ScatterUpdateMode::ScatterUpdate;
axisRelaxed = true;
} else if (scatterUpdateType == ScatterElementsUpdate) {
scatterUpdateMode = ScatterUpdateMode::ScatterElementsUpdate;
axisRelaxed = true;
} else if (scatterUpdateType == ScatterNDUpdate) {
scatterUpdateMode = ScatterUpdateMode::ScatterNDUpdate;
axisRelaxed = false;
} else {
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' is not supported";
}
}
void MKLDNNScatterUpdateNode::initSupportedPrimitiveDescriptors() {
if (!supportedPrimitiveDescriptors.empty())
return;
auto srcDataDim = getParentEdgeAt(DATA_ID)->getDims();
auto indicesDim = getParentEdgeAt(INDICES_ID)->getDims();
auto updateDim = getParentEdgeAt(UPDATE_ID)->getDims();
auto dstDataDim = getChildEdgeAt(0)->getDims();
size_t srcRank = srcDataDim.ndims();
size_t indicesRank = indicesDim.ndims();
size_t updateRank = updateDim.ndims();
size_t dstRank = dstDataDim.ndims();
// common check
if (srcRank != dstRank) {
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' should have same rank for input and outpt tensor";
} else {
for (size_t r = 0; r < srcRank; r++) {
if (srcDataDim[r] != dstDataDim[r]) {
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' should have same shape for input and outpt tensor." << " The input shape is "
<< srcDataDim[r] << ", while output shape is " << dstDataDim[r] << "for" << r << "th dimension";
}
}
}
// specific check
switch (scatterUpdateMode) {
case ScatterUpdateMode::ScatterUpdate: {
if (updateRank != (srcRank + indicesRank - 1)) {
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' do not have matched tensor rank relationship for input, indices and update";
}
break;
}
case ScatterUpdateMode::ScatterNDUpdate: {
size_t k = indicesDim[indicesRank - 1];
if (k > srcRank) {
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' do not have an correct indices' last dimension value, which should be smaller than or equal to input tensor rank";
}
SizeVector expectUpdateShape = {};
size_t tupleRank = indicesRank - 1;
for (size_t ri = 0; ri < tupleRank; ri++) {
expectUpdateShape.push_back(indicesDim[ri]);
}
for (size_t rd = k; rd < srcRank; rd++) {
expectUpdateShape.push_back(srcDataDim[rd]);
}
if (expectUpdateShape.size() != updateRank) {
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' do not have matched tensor rank relationship for input, indices and update";
}
for (size_t ru = 0; ru < updateRank; ru++) {
if (updateDim[ru] != expectUpdateShape[ru]) {
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' do not have matched tensor shape relationship for input, indices and update";
}
}
break;
}
case ScatterUpdateMode::ScatterElementsUpdate: {
if (srcRank != indicesRank || srcRank != updateRank) {
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' do not have the same tensor rank for input, indices and update";
}
for (size_t ri = 0; ri < indicesRank; ri++) {
if (indicesDim[ri] != updateDim[ri]) {
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' do not have the same tensor shape for indices and update";
}
}
break;
}
default: {
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' is not supported";
}
}
indicesPrec = getCnnLayer()->insData[INDICES_ID].lock()->getPrecision();
auto indicesType = MKLDNNExtensionUtils::IEPrecisionToDataType(indicesPrec);
indicesSize = MKLDNNExtensionUtils::sizeOfDataType(indicesType);
if (indicesSize >= 8) {
indicesPrec = Precision::I64;
indicesSize = 8;
} else {
indicesPrec = Precision::I32;
indicesSize = 4;
}
indicesType = MKLDNNExtensionUtils::IEPrecisionToDataType(indicesPrec);
if (axisRelaxed) {
axisPrec = getCnnLayer()->insData[AXIS_ID].lock()->getPrecision();
auto axisType = MKLDNNExtensionUtils::IEPrecisionToDataType(axisPrec);
axisSize = MKLDNNExtensionUtils::sizeOfDataType(axisType);
if (axisSize >= 8) {
axisPrec = Precision::I64;
axisSize = 8;
} else {
axisPrec = Precision::I32;
axisSize = 4;
}
}
dataPrec = getCnnLayer()->insData[DATA_ID].lock()->getPrecision();
auto dataType = MKLDNNExtensionUtils::IEPrecisionToDataType(dataPrec);
dataSize = MKLDNNExtensionUtils::sizeOfDataType(dataType);
bool canBeInplace = getParentEdgeAt(DATA_ID)->getParent()->getChildEdges().size() == 1;
InferenceEngine::LayerConfig config;
config.dynBatchSupport = false;
if (axisRelaxed) {
config.inConfs.resize(4);
} else {
config.inConfs.resize(3);
}
config.outConfs.resize(1);
config.inConfs[DATA_ID].constant = false;
config.inConfs[INDICES_ID].constant = false;
config.inConfs[UPDATE_ID].constant = false;
config.outConfs[0].constant = false;
config.inConfs[DATA_ID].inPlace = canBeInplace ? 0 : -1;
config.inConfs[INDICES_ID].inPlace = -1;
config.inConfs[UPDATE_ID].inPlace = -1;
config.outConfs[0].inPlace = canBeInplace ? 0 : -1;
if (axisRelaxed) {
config.inConfs[AXIS_ID].constant = false;
config.inConfs[AXIS_ID].inPlace = -1;
}
auto pushDesc = [&](memory::format inFormat, memory::format idxFormat, memory::format updateFormat, memory::format outFormat) {
config.inConfs[DATA_ID].desc = MKLDNNMemoryDesc(getParentEdgeAt(DATA_ID)->getDims(), dataType, inFormat);
config.inConfs[INDICES_ID].desc = MKLDNNMemoryDesc(getParentEdgeAt(INDICES_ID)->getDims(), indicesType, idxFormat);
config.inConfs[UPDATE_ID].desc = MKLDNNMemoryDesc(getParentEdgeAt(UPDATE_ID)->getDims(), dataType, updateFormat);
if (axisRelaxed)
config.inConfs[AXIS_ID].desc = MKLDNNMemoryDesc(getParentEdgeAt(AXIS_ID)->getDims(),
MKLDNNExtensionUtils::IEPrecisionToDataType(axisPrec), memory::x);
config.outConfs[0].desc = MKLDNNMemoryDesc(getChildEdgeAt(0)->getDims(), dataType, outFormat);
supportedPrimitiveDescriptors.push_back({config, impl_desc_type::unknown, outFormat});
};
pushDesc(MKLDNNMemory::GetPlainFormat(memory::dims(getParentEdgeAt(DATA_ID)->getDims())),
MKLDNNMemory::GetPlainFormat(memory::dims(getParentEdgeAt(INDICES_ID)->getDims())),
MKLDNNMemory::GetPlainFormat(memory::dims(getParentEdgeAt(UPDATE_ID)->getDims())),
MKLDNNMemory::GetPlainFormat(memory::dims(getChildEdgeAt(0)->getDims())));
}
void MKLDNNScatterUpdateNode::createPrimitive() {
auto &dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
auto &srcMemPtr = getParentEdgeAt(DATA_ID)->getMemoryPtr();
auto &indicesMemPtr = getParentEdgeAt(INDICES_ID)->getMemoryPtr();
auto &updateMemPtr = getParentEdgeAt(UPDATE_ID)->getMemoryPtr();
if (!dstMemPtr || !dstMemPtr->GetPrimitivePtr())
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' did not allocate destination memory";
if (!srcMemPtr || !srcMemPtr->GetPrimitivePtr())
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' did not allocate input memory";
if (!indicesMemPtr || !indicesMemPtr->GetPrimitivePtr())
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' did not allocate indices memory";
if (!updateMemPtr || !updateMemPtr->GetPrimitivePtr())
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' did not allocate update memory";
if (getSelectedPrimitiveDescriptor() == nullptr)
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' did not set preferable primitive descriptor";
}
int64_t MKLDNNScatterUpdateNode::getIndicesValue(uint8_t *indices, size_t offset) {
auto *indicesPtr = indices + offset * indicesSize;
int64_t ret = 0;
if (indicesSize == 4) {
auto *indicesPtr32 = reinterpret_cast<int32_t*>(indicesPtr);
ret = *indicesPtr32;
} else {
auto *indicesPtr64 = reinterpret_cast<int64_t*>(indicesPtr);
ret = *indicesPtr64;
}
return ret;
}
// 5D example:
// shapeND: n c d h w
// blockND: ncdhw cdhw dhw hw w 1
// index : 0 1 2 3 4 5
std::vector<size_t> getBlockND(const SizeVector& shape) {
size_t shapeRank = shape.size();
std::vector<size_t> blockND(shapeRank + 1, 1);
for (int i = shapeRank - 1; i >= 0; i--) {
blockND[i] = shape[i] * blockND[i+1];
}
return blockND;
}
void MKLDNNScatterUpdateNode::execute(mkldnn::stream strm) {
auto &dstMemPtr = getChildEdgeAt(0)->getMemoryPtr();
auto &srcMemPtr = getParentEdgeAt(DATA_ID)->getMemoryPtr();
auto &indicesMemPtr = getParentEdgeAt(INDICES_ID)->getMemoryPtr();
auto &updateMemPtr = getParentEdgeAt(UPDATE_ID)->getMemoryPtr();
uint8_t *dstPtr = reinterpret_cast<uint8_t*>(dstMemPtr->GetData()) +
dstMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding * dataSize;
uint8_t *srcPtr = reinterpret_cast<uint8_t*>(srcMemPtr->GetData()) +
srcMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding * dataSize;
uint8_t *indicesPtr = reinterpret_cast<uint8_t*>(indicesMemPtr->GetData()) +
indicesMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding * indicesSize;
uint8_t *updatePtr = reinterpret_cast<uint8_t*>(updateMemPtr->GetData()) +
updateMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding * dataSize;
SizeVector srcDataDim = getParentEdgeAt(DATA_ID)->getDesc().getDims();
SizeVector indicesDim = getParentEdgeAt(INDICES_ID)->getDesc().getDims();
size_t srcRank = srcDataDim.size();
int axis = 0;
if (axisRelaxed) {
auto &axisMemPtr = getParentEdgeAt(AXIS_ID)->getMemoryPtr();
uint8_t *axisPtr = reinterpret_cast<uint8_t*>(axisMemPtr->GetData()) +
axisMemPtr->GetDescriptor().data.layout_desc.blocking.offset_padding * axisSize;
if (axisSize == 4) {
auto *axisPtr32 = reinterpret_cast<int32_t*>(axisPtr);
axis = *axisPtr32;
} else {
auto *axisPtr64 = reinterpret_cast<int64_t*>(axisPtr);
axis = *axisPtr64;
}
if (axis >= static_cast<int>(srcRank) || axis < (static_cast<int>(srcRank) * - 1)) {
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' should have axis value in range [-r, r - 1], where r is the rank of input data";
}
axis = axis < 0 ? (axis + srcRank) : axis;
size_t srcDimAxis = srcDataDim[axis];
std::vector<size_t> indicesBlockND = getBlockND(indicesDim);
parallel_nt(0, [&](const int ithr, const int nthr) {
size_t start = 0, end = 0;
splitter(indicesBlockND[0], nthr, ithr, start, end);
for (int i = start; i < end; i++) {
int64_t idxValue = getIndicesValue(indicesPtr, i);
if (idxValue >= static_cast<int64_t>(srcDimAxis) || idxValue < 0) {
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' have indices value that points to non-existing output tensor element";
}
}
});
if (scatterUpdateMode == ScatterUpdateMode::ScatterUpdate) {
SizeVector indicesDim = getParentEdgeAt(INDICES_ID)->getDesc().getDims();
SizeVector updateDim = getParentEdgeAt(UPDATE_ID)->getDesc().getDims();
size_t indicesRank = indicesDim.size();
size_t updateRank = updateDim.size();
SizeVector expectUpdateShape = {};
for (size_t rs = 0; rs < srcRank; rs++) {
if (rs != axis) {
expectUpdateShape.push_back(srcDataDim[rs]);
} else {
for (size_t ri = 0; ri < indicesRank; ri++) {
expectUpdateShape.push_back(indicesDim[ri]);
}
}
}
for (size_t ru = 0; ru < updateRank; ru++) {
if (updateDim[ru] != expectUpdateShape[ru]) {
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' do not have matched tensor shape relationship for input, indices and update";
}
}
}
}
if (srcPtr != dstPtr) {
std::vector<size_t> srcBlockND = getBlockND(srcDataDim);
parallel_nt(0, [&](const int ithr, const int nthr) {
size_t start = 0, end = 0;
splitter(srcBlockND[0], nthr, ithr, start, end);
size_t size = (end - start) * dataSize;
start *= dataSize;
simple_copy(dstPtr + start, size, srcPtr + start, size);
});
}
switch (scatterUpdateMode) {
case ScatterUpdateMode::ScatterUpdate: {
scatterUpdate(indicesPtr, updatePtr, axis, dstPtr);
break;
}
case ScatterUpdateMode::ScatterNDUpdate: {
scatterNDUpdate(indicesPtr, updatePtr, dstPtr);
break;
}
case ScatterUpdateMode::ScatterElementsUpdate: {
scatterElementsUpdate(indicesPtr, updatePtr, axis, dstPtr);
break;
}
default: {
THROW_IE_EXCEPTION << "'" << getType() << "'" << " layer with name '" << getName()
<< "' is not supported";
}
}
}
// For the data tensor of shape [d_0, d_1, ..., d_n],
// and indices tensor of shape [i_0, i_1, ..., i_k].
// Updates tensor shape should be [d_0, d_1, ... d_(axis - 1), i_0, i_1, ..., i_k, d_(axis + 1), ..., d_n].
void MKLDNNScatterUpdateNode::scatterUpdate(uint8_t *indices, uint8_t *update, int axis, uint8_t *dstData) {
SizeVector srcDataDim = getParentEdgeAt(DATA_ID)->getDesc().getDims();
SizeVector indicesDim = getParentEdgeAt(INDICES_ID)->getDesc().getDims();
SizeVector updateDim = getParentEdgeAt(UPDATE_ID)->getDesc().getDims();
size_t indicesRank = indicesDim.size();
std::vector<size_t> srcBlockND = getBlockND(srcDataDim);
std::vector<size_t> updateBlockND = getBlockND(updateDim);
const size_t mulIdentity = 1;
size_t idxLength = mulIdentity;
for (size_t ri = 0; ri < indicesRank; ri++) {
idxLength *= indicesDim[ri];
}
size_t batchToUpdate = mulIdentity;
for (size_t x = 0; x < axis; x++) {
batchToUpdate *= srcDataDim[x];
}
// blockToUpdate is srcBlockND[axis + 1], also is updateBlockND[axis + indicesRank]
size_t blockToUpdate = srcBlockND[axis + 1];
size_t blockToUpdateSize = blockToUpdate * dataSize;
parallel_for2d(batchToUpdate, idxLength, [&](size_t b, size_t idx) {
int64_t idxValue = getIndicesValue(indices, idx);
uint8_t *dstEntry = dstData + (b * srcBlockND[axis] + idxValue * blockToUpdate) * dataSize;
uint8_t *updateEntry = update + (b * updateBlockND[axis] + idx * blockToUpdate) * dataSize;
simple_copy(dstEntry, blockToUpdateSize, updateEntry, blockToUpdateSize);
});
}
// indices is a (q-1)-dimension tensor of k-tuple,
// k is indices.shape[-1] and should not be greater than rank of input, q is rank of indicies.
// updates is a (q-1)-dimension tensor of replacement-slice-values
void MKLDNNScatterUpdateNode::scatterNDUpdate(uint8_t *indices, uint8_t *update, uint8_t *dstData) {
SizeVector srcDataDim = getParentEdgeAt(DATA_ID)->getDesc().getDims();
SizeVector indicesDim = getParentEdgeAt(INDICES_ID)->getDesc().getDims();
size_t indicesRank = indicesDim.size();
std::vector<size_t> srcBlockND = getBlockND(srcDataDim);
size_t k = indicesDim[indicesRank - 1];
size_t idxTupleNum = 1;
for (size_t ri = 0; ri < indicesRank - 1; ri++) {
idxTupleNum *= indicesDim[ri];
}
size_t sizeToUpdate = srcBlockND[k] * dataSize;
parallel_for(idxTupleNum, [&](size_t tupleIdx) {
size_t indicesOffset = tupleIdx * k;
size_t dstOffset = 0;
for (int i = 0; i < k; i++) {
size_t idxValue = getIndicesValue(indices, indicesOffset + i);
dstOffset += idxValue * srcBlockND[i + 1];
}
dstOffset *= dataSize;
size_t updateOffset = tupleIdx * sizeToUpdate;
simple_copy(dstData + dstOffset, sizeToUpdate, update + updateOffset, sizeToUpdate);
});
}
// output[indices[i][j][k]][j][k] = updates[i][j][k] if axis = 0,
// output[i][indices[i][j][k]][k] = updates[i][j][k] if axis = 1,
// output[i][j][indices[i][j][k]] = updates[i][j][k] if axis = 2.
void MKLDNNScatterUpdateNode::scatterElementsUpdate(uint8_t *indices, uint8_t *update, int axis, uint8_t *dstData) {
SizeVector srcDataDim = getParentEdgeAt(DATA_ID)->getDesc().getDims();
SizeVector updateDim = getParentEdgeAt(UPDATE_ID)->getDesc().getDims();
SizeVector indicesDim = getParentEdgeAt(INDICES_ID)->getDesc().getDims();
size_t srcRank = srcDataDim.size();
size_t updateRank = updateDim.size();
std::vector<size_t> srcBlockND = getBlockND(srcDataDim);
std::vector<size_t> updateBlockND = getBlockND(updateDim);
parallel_nt(0, [&](const int ithr, const int nthr) {
int j;
size_t i, dst_idx = 0, start = 0, end = 0;
SizeVector tensorItr(updateRank, 0);
splitter(updateBlockND[0], nthr, ithr, start, end);
for (j = updateRank - 1, i = start; j >= 0; j--) {
tensorItr[j] = i % updateDim[j];
i /= updateDim[j];
}
for (i = 0; i < static_cast<size_t>(axis); ++i)
dst_idx += tensorItr[i] * srcBlockND[i + 1];
for (i++; i < updateRank; ++i)
dst_idx += tensorItr[i] * srcBlockND[i + 1];
for (size_t iwork = start; iwork < end; iwork++) {
int64_t idxValue = getIndicesValue(indices, iwork);
if (idxValue < srcDataDim[axis])
simple_copy(dstData + dataSize * (dst_idx + idxValue * srcBlockND[axis + 1]), dataSize,
update + iwork * dataSize, dataSize);
for (j = updateRank - 1; j >= 0; j--) {
tensorItr[j]++;
if (tensorItr[j] < updateDim[j]) {
if (j != static_cast<size_t>(axis))
dst_idx += srcBlockND[j + 1];
break;
} else {
tensorItr[j] = 0;
for (dst_idx = 0, i = 0; i < static_cast<size_t>(axis); ++i)
dst_idx += tensorItr[i] * srcBlockND[i + 1];
for (i++; i < updateRank; ++i)
dst_idx += tensorItr[i] * srcBlockND[i + 1];
}
}
}
});
}
bool MKLDNNScatterUpdateNode::created() const {
return getType() == ScatterUpdate || getType() == ScatterElementsUpdate || getType() == ScatterNDUpdate;
}
REG_MKLDNN_PRIM_FOR(MKLDNNScatterUpdateNode, ScatterUpdate);
REG_MKLDNN_PRIM_FOR(MKLDNNScatterUpdateNode, ScatterElementsUpdate);
REG_MKLDNN_PRIM_FOR(MKLDNNScatterUpdateNode, ScatterNDUpdate);

View File

@ -0,0 +1,53 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <ie_common.h>
#include <mkldnn_node.h>
#include <string>
#include <memory>
#include <vector>
namespace MKLDNNPlugin {
enum class ScatterUpdateMode {
ScatterUpdate,
ScatterNDUpdate,
ScatterElementsUpdate
};
class MKLDNNScatterUpdateNode : public MKLDNNNode {
public:
MKLDNNScatterUpdateNode(const InferenceEngine::CNNLayerPtr& layer, const mkldnn::engine& eng, MKLDNNWeightsSharing::Ptr &cache);
~MKLDNNScatterUpdateNode() override = default;
void getSupportedDescriptors() override;
void initSupportedPrimitiveDescriptors() override;
void createPrimitive() override;
bool created() const override;
void execute(mkldnn::stream strm) override;
bool canBeInPlace() const override {
return false;
}
private:
void scatterUpdate(uint8_t *indicesPtr, uint8_t *updatePtr, int axis, uint8_t *dstDataPtr);
void scatterNDUpdate(uint8_t *indicesPtr, uint8_t *updatePtr, uint8_t *dstDataPtr);
void scatterElementsUpdate(uint8_t *indicesPtr, uint8_t *updatePtr, int axis, uint8_t *dstDataPtr);
inline int64_t getIndicesValue(uint8_t *indices, size_t offset);
ScatterUpdateMode scatterUpdateMode = ScatterUpdateMode::ScatterUpdate;
const size_t DATA_ID = 0;
const size_t INDICES_ID = 1;
const size_t UPDATE_ID = 2;
const size_t AXIS_ID = 3;
// if axis can be set other than default 0.
bool axisRelaxed = false;
size_t dataSize, indicesSize, axisSize;
InferenceEngine::Precision dataPrec, indicesPrec, axisPrec;
};
} // namespace MKLDNNPlugin

View File

@ -0,0 +1,42 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include <ngraph/opsets/opset3.hpp>
#include "single_layer_tests/scatter_ND_update.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
using namespace ngraph::opset3;
namespace {
const std::vector<InferenceEngine::Precision> inputPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
InferenceEngine::Precision::I32,
};
const std::vector<InferenceEngine::Precision> idxPrecisions = {
InferenceEngine::Precision::I32,
InferenceEngine::Precision::I64,
};
// map<inputShape map<indicesShape, indicesValue>>
// updateShape is gotten from inputShape and indicesShape
std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<size_t>>> sliceSelectInShape {
{{10, 9, 9, 11}, {{{4, 1}, {1, 3, 5, 7}}, {{1, 2}, {4, 6}}, {{2, 3}, {0, 1, 1, 2, 2, 2}}, {{1, 4}, {5, 5, 4, 9}}}},
{{10, 9, 10, 9, 10}, {{{2, 2, 1}, {5, 6, 2, 8}}, {{2, 3}, {0, 4, 6, 5, 7, 1}}}},
};
const auto ScatterNDUpdateCases = ::testing::Combine(
::testing::ValuesIn(ScatterNDUpdateLayerTest::combineShapes(sliceSelectInShape)),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU)
);
// open after ops support in ngraph merged
// INSTANTIATE_TEST_CASE_P(ScatterNDUpdate, ScatterNDUpdateLayerTest, ScatterNDUpdateCases, ScatterNDUpdateLayerTest::getTestCaseName);
} // namespace

View File

@ -0,0 +1,48 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include <ngraph/opsets/opset3.hpp>
#include "single_layer_tests/scatter_elements_update.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
using namespace ngraph::opset3;
namespace {
// map<inputShape, map<indicesShape, axis>>
std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<int>>> axesShapeInShape {
{{10, 12, 15}, {{{1, 2, 4}, {0, 1, 2}}, {{2, 2, 2}, {-1, -2, -3}}}},
{{15, 9, 8, 12}, {{{1, 2, 2, 2}, {0, 1, 2, 3}}, {{1, 2, 1, 4}, {-1, -2, -3, -4}}}},
{{9, 9, 8, 8, 11, 10}, {{{1, 2, 1, 2, 1, 2}, {5, -3}}}},
};
// index value should not be random data
const std::vector<std::vector<size_t>> idxValue = {
{1, 0, 4, 6, 2, 3, 7, 5}
};
const std::vector<InferenceEngine::Precision> inputPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
InferenceEngine::Precision::I32,
};
const std::vector<InferenceEngine::Precision> idxPrecisions = {
InferenceEngine::Precision::I32,
InferenceEngine::Precision::I64,
};
const auto ScatterEltUpdateCases = ::testing::Combine(
::testing::ValuesIn(ScatterElementsUpdateLayerTest::combineShapes(axesShapeInShape)),
::testing::ValuesIn(idxValue),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU)
);
INSTANTIATE_TEST_CASE_P(ScatterEltsUpdate, ScatterElementsUpdateLayerTest,
ScatterEltUpdateCases, ScatterElementsUpdateLayerTest::getTestCaseName);
} // namespace

View File

@ -0,0 +1,47 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <vector>
#include <ngraph/opsets/opset3.hpp>
#include "single_layer_tests/scatter_update.hpp"
#include "common_test_utils/test_constants.hpp"
using namespace LayerTestsDefinitions;
using namespace ngraph::opset3;
namespace {
const std::vector<InferenceEngine::Precision> inputPrecisions = {
InferenceEngine::Precision::FP32,
InferenceEngine::Precision::FP16,
InferenceEngine::Precision::I32,
};
const std::vector<InferenceEngine::Precision> idxPrecisions = {
InferenceEngine::Precision::I32,
InferenceEngine::Precision::I64,
};
// map<inputShape, map<indicesShape, axis>>
std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<int>>> axesShapeInShape {
{{10, 16, 12, 15}, {{{2, 4}, {0, 1, 2, 3}}, {{8}, {-1, -2, -3, -4}}}},
{{10, 9, 10, 9, 10}, {{{8}, {-3, -1, 0, 2, 4}}, {{4, 2}, {-2, 2}}}},
};
//indices should not be random value
const std::vector<std::vector<size_t>> idxValue = {
{0, 2, 4, 6, 1, 3, 5, 7}
};
const auto ScatterUpdateCase = ::testing::Combine(
::testing::ValuesIn(ScatterUpdateLayerTest::combineShapes(axesShapeInShape)),
::testing::ValuesIn(idxValue),
::testing::ValuesIn(inputPrecisions),
::testing::ValuesIn(idxPrecisions),
::testing::Values(CommonTestUtils::DEVICE_CPU)
);
// open after ngraph reference implementation merged
// INSTANTIATE_TEST_CASE_P(ScatterUpdate, ScatterUpdateLayerTest, ScatterUpdateCase, ScatterUpdateLayerTest::getTestCaseName);
} // namespace

View File

@ -0,0 +1,37 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include "functional_test_utils/layer_test_utils.hpp"
namespace LayerTestsDefinitions {
using sliceSelcetInShape = std::tuple<
std::vector<size_t>, // input shape
std::vector<size_t>, // indices shape
std::vector<size_t>, // indices value
std::vector<size_t>>; // update shape
using scatterNDUpdateParamsTuple = typename std::tuple<
sliceSelcetInShape, // Input description
InferenceEngine::Precision, // Network precision
InferenceEngine::Precision, // indices precision
std::string>; // Device name
class ScatterNDUpdateLayerTest : public testing::WithParamInterface<scatterNDUpdateParamsTuple>,
public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(const testing::TestParamInfo<scatterNDUpdateParamsTuple> &obj);
static std::vector<sliceSelcetInShape> combineShapes(
const std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<size_t>>>& inputShapes);
protected:
void SetUp() override;
};
} // namespace LayerTestsDefinitions

View File

@ -0,0 +1,37 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include "functional_test_utils/layer_test_utils.hpp"
namespace LayerTestsDefinitions {
using axisShapeInShape = std::tuple<
std::vector<size_t>, // input shape
std::vector<size_t>, // update shape
int>; // axis
using scatterElementsUpdateParamsTuple = typename std::tuple<
axisShapeInShape, // shape description
std::vector<size_t>, // indices value
InferenceEngine::Precision, // Network precision
InferenceEngine::Precision, // indices precision
std::string>; // Device name
class ScatterElementsUpdateLayerTest : public testing::WithParamInterface<scatterElementsUpdateParamsTuple>,
public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(const testing::TestParamInfo<scatterElementsUpdateParamsTuple> &obj);
static std::vector<axisShapeInShape> combineShapes(
const std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<int>>>& inputShapes);
protected:
void SetUp() override;
};
} // namespace LayerTestsDefinitions

View File

@ -0,0 +1,38 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#pragma once
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include "functional_test_utils/layer_test_utils.hpp"
namespace LayerTestsDefinitions {
using axisShapeInShape = std::tuple<
std::vector<size_t>, // input shape
std::vector<size_t>, // indices shape
std::vector<size_t>, // update shape
int>; // axis
using scatterUpdateParamsTuple = typename std::tuple<
axisShapeInShape, // shape description
std::vector<size_t>, // indices value
InferenceEngine::Precision, // input precision
InferenceEngine::Precision, // indices precision
std::string>; // Device name
class ScatterUpdateLayerTest : public testing::WithParamInterface<scatterUpdateParamsTuple>,
public LayerTestsUtils::LayerTestsCommon {
public:
static std::string getTestCaseName(const testing::TestParamInfo<scatterUpdateParamsTuple> &obj);
static std::vector<axisShapeInShape> combineShapes(
const std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<int>>>& inputShapes);
protected:
void SetUp() override;
};
} // namespace LayerTestsDefinitions

View File

@ -0,0 +1,94 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include <ie_core.hpp>
#include <ngraph_functions/builders.hpp>
#include "functional_test_utils/blob_utils.hpp"
#include "functional_test_utils/precision_utils.hpp"
#include "common_test_utils/common_utils.hpp"
#include "single_layer_tests/scatter_ND_update.hpp"
using namespace ngraph::opset3;
namespace LayerTestsDefinitions {
std::string ScatterNDUpdateLayerTest::getTestCaseName(const testing::TestParamInfo<scatterNDUpdateParamsTuple> &obj) {
sliceSelcetInShape shapeDescript;
std::vector<size_t> inShape;
std::vector<size_t> indicesShape;
std::vector<size_t> indicesValue;
std::vector<size_t> updateShape;
InferenceEngine::Precision inputPrecision;
InferenceEngine::Precision indicesPrecision;
std::string targetName;
std::tie(shapeDescript, inputPrecision, indicesPrecision, targetName) = obj.param;
std::tie(inShape, indicesShape, indicesValue, updateShape) = shapeDescript;
std::ostringstream result;
result << "InputShape=" << CommonTestUtils::vec2str(inShape) << "_";
result << "IndicesShape=" << CommonTestUtils::vec2str(indicesShape) << "_";
result << "UpdateShape=" << CommonTestUtils::vec2str(updateShape) << "_";
result << "inPrc=" << inputPrecision.name() << "_";
result << "idxPrc=" << indicesPrecision.name() << "_";
result << "targetDevice=" << targetName << "_";
return result.str();
}
std::vector<sliceSelcetInShape> ScatterNDUpdateLayerTest::combineShapes(
const std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<size_t>>>& inputShapes) {
std::vector<sliceSelcetInShape> resVec;
for (auto& inputShape : inputShapes) {
for (auto& item : inputShape.second) {
auto indiceShape = item.first;
size_t indicesRank = indiceShape.size();
std::vector<size_t> updateShape;
for (size_t i = 0; i < indicesRank - 1; i++) {
updateShape.push_back(indiceShape[i]);
}
auto srcShape = inputShape.first;
for (size_t j = indiceShape[indicesRank - 1]; j < srcShape.size(); j++) {
updateShape.push_back(srcShape[j]);
}
resVec.push_back(std::make_tuple(srcShape, indiceShape, item.second, updateShape));
}
}
return resVec;
}
void ScatterNDUpdateLayerTest::SetUp() {
sliceSelcetInShape shapeDescript;
InferenceEngine::SizeVector inShape;
InferenceEngine::SizeVector indicesShape;
InferenceEngine::SizeVector indicesValue;
InferenceEngine::SizeVector updateShape;
InferenceEngine::Precision inputPrecision;
InferenceEngine::Precision indicesPrecision;
std::tie(shapeDescript, inputPrecision, indicesPrecision, targetDevice) = this->GetParam();
std::tie(inShape, indicesShape, indicesValue, updateShape) = shapeDescript;
auto inPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inputPrecision);
auto idxPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(indicesPrecision);
ngraph::ParameterVector paramVector;
auto inputParams = std::make_shared<ngraph::opset1::Parameter>(inPrc, ngraph::Shape(inShape));
paramVector.push_back(inputParams);
auto updateParams = std::make_shared<ngraph::opset1::Parameter>(inPrc, ngraph::Shape(updateShape));
paramVector.push_back(updateParams);
auto paramVectorOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(paramVector));
auto s2d = ngraph::builder::makeScatterNDUpdate(paramVectorOuts[0], idxPrc, indicesShape, indicesValue, paramVectorOuts[1]);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(s2d)};
function = std::make_shared<ngraph::Function>(results, paramVector, "ScatterNDUpdate");
}
TEST_P(ScatterNDUpdateLayerTest, CompareWithRefs) {
Run();
};
} // namespace LayerTestsDefinitions

View File

@ -0,0 +1,82 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include <ie_core.hpp>
#include <ngraph_functions/builders.hpp>
#include "functional_test_utils/blob_utils.hpp"
#include "functional_test_utils/precision_utils.hpp"
#include "common_test_utils/common_utils.hpp"
#include "single_layer_tests/scatter_elements_update.hpp"
using namespace ngraph::opset3;
namespace LayerTestsDefinitions {
std::string ScatterElementsUpdateLayerTest::getTestCaseName(const testing::TestParamInfo<scatterElementsUpdateParamsTuple> &obj) {
axisShapeInShape shapeDescript;
InferenceEngine::SizeVector indicesValue;
InferenceEngine::Precision inputPrecision;
InferenceEngine::Precision indicesPrecision;
std::string targetName;
std::tie(shapeDescript, indicesValue, inputPrecision, indicesPrecision, targetName) = obj.param;
std::ostringstream result;
result << "InputShape=" << CommonTestUtils::vec2str(std::get<0>(shapeDescript)) << "_";
result << "IndicesShape=" << CommonTestUtils::vec2str(std::get<1>(shapeDescript)) << "_";
result << "Axis=" << std::get<2>(shapeDescript) << "_";
result << "inPrc=" << inputPrecision.name() << "_";
result << "idxPrc=" << indicesPrecision.name() << "_";
result << "targetDevice=" << targetName << "_";
return result.str();
}
std::vector<axisShapeInShape> ScatterElementsUpdateLayerTest::combineShapes(
const std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<int>>>& inputShapes) {
std::vector<axisShapeInShape> resVec;
for (auto& inputShape : inputShapes) {
for (auto& item : inputShape.second) {
for (auto& elt : item.second) {
resVec.push_back(std::make_tuple(inputShape.first, item.first, elt));
}
}
}
return resVec;
}
void ScatterElementsUpdateLayerTest::SetUp() {
InferenceEngine::SizeVector inShape;
InferenceEngine::SizeVector indicesShape;
int axis;
axisShapeInShape shapeDescript;
InferenceEngine::SizeVector indicesValue;
InferenceEngine::Precision inputPrecision;
InferenceEngine::Precision indicesPrecision;
std::tie(shapeDescript, indicesValue, inputPrecision, indicesPrecision, targetDevice) = this->GetParam();
std::tie(inShape, indicesShape, axis) = shapeDescript;
auto inPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inputPrecision);
auto idxPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(indicesPrecision);
ngraph::ParameterVector paramVector;
auto inputParams = std::make_shared<ngraph::opset1::Parameter>(inPrc, ngraph::Shape(inShape));
paramVector.push_back(inputParams);
auto updateParams = std::make_shared<ngraph::opset1::Parameter>(inPrc, ngraph::Shape(indicesShape));
paramVector.push_back(updateParams);
auto paramVectorOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(paramVector));
auto s2d = ngraph::builder::makeScatterElementsUpdate(paramVectorOuts[0], idxPrc, indicesShape, indicesValue, paramVectorOuts[1], axis);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(s2d)};
function = std::make_shared<ngraph::Function>(results, paramVector, "ScatterElementsUpdate");
}
TEST_P(ScatterElementsUpdateLayerTest, CompareWithRefs) {
Run();
};
} // namespace LayerTestsDefinitions

View File

@ -0,0 +1,105 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include <algorithm>
#include <functional>
#include <memory>
#include <string>
#include <tuple>
#include <vector>
#include <ie_core.hpp>
#include <ngraph_functions/builders.hpp>
#include "functional_test_utils/blob_utils.hpp"
#include "functional_test_utils/precision_utils.hpp"
#include "common_test_utils/common_utils.hpp"
#include "single_layer_tests/scatter_update.hpp"
using namespace ngraph::opset3;
namespace LayerTestsDefinitions {
std::string ScatterUpdateLayerTest::getTestCaseName(const testing::TestParamInfo<scatterUpdateParamsTuple> &obj) {
axisShapeInShape shapeDescript;
std::vector<size_t> inShape;
std::vector<size_t> indicesShape;
std::vector<size_t> updateShape;
int axis;
std::vector<size_t> indicesValue;
InferenceEngine::Precision inputPrecision;
InferenceEngine::Precision indicesPrecision;
std::string targetName;
std::tie(shapeDescript, indicesValue, inputPrecision, indicesPrecision, targetName) = obj.param;
std::tie(inShape, indicesShape, updateShape, axis) = shapeDescript;
std::ostringstream result;
result << "InputShape=" << CommonTestUtils::vec2str(inShape) << "_";
result << "IndicesShape=" << CommonTestUtils::vec2str(indicesShape) << "_";
result << "IndicesValue=" << CommonTestUtils::vec2str(indicesValue) << "_";
result << "UpdateShape=" << CommonTestUtils::vec2str(updateShape) << "_";
result << "Axis=" << axis << "_";
result << "inPrc=" << inputPrecision.name() << "_";
result << "idxPrc=" << indicesPrecision.name() << "_";
result << "targetDevice=" << targetName << "_";
return result.str();
}
std::vector<axisShapeInShape> ScatterUpdateLayerTest::combineShapes(
const std::map<std::vector<size_t>, std::map<std::vector<size_t>, std::vector<int>>>& inputShapes) {
std::vector<axisShapeInShape> resVec;
for (auto& inputShape : inputShapes) {
auto srcShape = inputShape.first;
auto srcRank = srcShape.size();
for (auto& item : inputShape.second) {
auto indicesShape = item.first;
auto indicesRank = indicesShape.size();
for (auto& axis : item.second) {
auto axisP = axis < 0 ? axis + srcRank : axis;
std::vector<size_t> updateShape;
for (size_t rs = 0; rs < srcRank; rs++) {
if (rs != axisP) {
updateShape.push_back(srcShape[rs]);
} else {
for (size_t ri = 0; ri < indicesRank; ri++) {
updateShape.push_back(indicesShape[ri]);
}
}
}
resVec.push_back(std::make_tuple(srcShape, indicesShape, updateShape, axis));
}
}
}
return resVec;
}
void ScatterUpdateLayerTest::SetUp() {
axisShapeInShape shapeDescript;
InferenceEngine::SizeVector inShape;
InferenceEngine::SizeVector indicesShape;
InferenceEngine::SizeVector updateShape;
int axis;
InferenceEngine::SizeVector indicesValue;
InferenceEngine::Precision inputPrecision;
InferenceEngine::Precision indicesPrecision;
std::tie(shapeDescript, indicesValue, inputPrecision, indicesPrecision, targetDevice) = this->GetParam();
std::tie(inShape, indicesShape, updateShape, axis) = shapeDescript;
auto inPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(inputPrecision);
auto idxPrc = FuncTestUtils::PrecisionUtils::convertIE2nGraphPrc(indicesPrecision);
ngraph::ParameterVector paramVector;
auto inputParams = std::make_shared<ngraph::opset1::Parameter>(inPrc, ngraph::Shape(inShape));
paramVector.push_back(inputParams);
auto updateParams = std::make_shared<ngraph::opset1::Parameter>(inPrc, ngraph::Shape(updateShape));
paramVector.push_back(updateParams);
auto paramVectorOuts = ngraph::helpers::convert2OutputVector(ngraph::helpers::castOps2Nodes<ngraph::op::Parameter>(paramVector));
auto s2d = ngraph::builder::makeScatterUpdate(paramVectorOuts[0], idxPrc, indicesShape, indicesValue, paramVectorOuts[1], axis);
ngraph::ResultVector results{std::make_shared<ngraph::opset1::Result>(s2d)};
function = std::make_shared<ngraph::Function>(results, paramVector, "ScatterUpdate");
}
TEST_P(ScatterUpdateLayerTest, CompareWithRefs) {
Run();
};
} // namespace LayerTestsDefinitions

View File

@ -270,5 +270,25 @@ std::shared_ptr<Node> makePooling(const ngraph::Output<Node> &in,
bool excludePad,
const ngraph::helpers::PoolingTypes &poolType);
std::shared_ptr<ngraph::Node> makeScatterUpdate(const ngraph::Output<Node> &in,
const element::Type& indicesType,
const std::vector<size_t>& indicesShape,
const std::vector<size_t>& indices,
const ngraph::Output<Node> &update,
std::size_t axis);
std::shared_ptr<ngraph::Node> makeScatterElementsUpdate(const ngraph::Output<Node> &in,
const element::Type& indicesType,
const std::vector<size_t>& indicesShape,
const std::vector<size_t>& indices,
const ngraph::Output<Node> &update,
int axis);
std::shared_ptr<ngraph::Node> makeScatterNDUpdate(const ngraph::Output<Node> &in,
const element::Type& indicesType,
const std::vector<size_t>& indicesShape,
const std::vector<size_t>& indices,
const ngraph::Output<Node> &update);
} // namespace builder
} // namespace ngraph

View File

@ -0,0 +1,23 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph_functions/builders.hpp"
namespace ngraph {
namespace builder {
std::shared_ptr<ngraph::Node> makeScatterNDUpdate(const ngraph::Output<Node> &in,
const element::Type& indicesType,
const std::vector<size_t>& indicesShape,
const std::vector<size_t>& indices,
const ngraph::Output<Node> &update) {
auto indicesNode = std::make_shared<ngraph::opset1::Constant>(indicesType, indicesShape, indices);
// blocked by ngraph merge
// auto dtsNode = std::make_shared<ngraph::opset3::ScatterNDUpdate>(in, indicesNode, update);
// return dtsNode;
return nullptr;
}
} // namespace builder
} // namespace ngraph

View File

@ -0,0 +1,24 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph_functions/builders.hpp"
namespace ngraph {
namespace builder {
std::shared_ptr<ngraph::Node> makeScatterElementsUpdate(const ngraph::Output<Node> &in,
const element::Type& indicesType,
const std::vector<size_t>& indicesShape,
const std::vector<size_t>& indices,
const ngraph::Output<Node> &update,
int axis) {
auto indicesNode = std::make_shared<ngraph::opset1::Constant>(indicesType, indicesShape, indices);
auto axis_node = std::make_shared<ngraph::opset1::Constant>(ngraph::element::Type_t::i32, ngraph::Shape{},
std::vector<int>{axis});
auto dtsNode = std::make_shared<ngraph::opset3::ScatterElementsUpdate>(in, indicesNode, update, axis_node);
return dtsNode;
}
} // namespace builder
} // namespace ngraph

View File

@ -0,0 +1,24 @@
// Copyright (C) 2020 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
#include "ngraph_functions/builders.hpp"
namespace ngraph {
namespace builder {
std::shared_ptr<ngraph::Node> makeScatterUpdate(const ngraph::Output<Node> &in,
const element::Type& indicesType,
const std::vector<size_t> &indicesShape,
const std::vector<size_t> &indices,
const ngraph::Output<Node> &update,
std::size_t axis) {
auto indicesNode = std::make_shared<ngraph::opset1::Constant>(indicesType, indicesShape, indices);
auto axis_node = std::make_shared<ngraph::opset1::Constant>(ngraph::element::Type_t::i64, ngraph::Shape{},
std::vector<uint64_t>{axis});
auto dtsNode = std::make_shared<ngraph::opset3::ScatterUpdate>(in, indicesNode, update, axis_node);
return dtsNode;
}
} // namespace builder
} // namespace ngraph