ICV: fix Scatter layers: fix validators (#541)

* ICV: fix Scatter layers: fix validators

* ICV: fix Scatter layers: enable 0D for `axis`

* Revert "ICV: fix Scatter layers: enable 0D for `axis`"

This reverts commit 82da24b989678061a585a5c7ffd7d5dab10f5edc.

* ICV: fix Scatter layers: test, fix CNNNetworkImpl
This commit is contained in:
Evgeny Latkin 2020-05-27 13:14:46 +03:00 committed by GitHub
parent 946ed119c8
commit d24132912e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 24 additions and 20 deletions

View File

@ -325,6 +325,9 @@ size_t CNNNetworkImpl::getBatchSize() const noexcept {
if (dims.size() == 3 || dims.size() == 1) {
return 1;
}
if (dims.size() == 0) {
return 0;
}
return dims.at(0);
}

View File

@ -3105,19 +3105,20 @@ void ScatterUpdateValidator::checkShapes(const CNNLayer* layer, const vector<Siz
if (inShapes[UPDATES].size() < 1)
THROW_IE_EXCEPTION << layer->name << " 'Updates' tensor rank must be >= 1";
if (!(inShapes[AXIS].size() == 1 && inShapes[AXIS][0] == 1))
THROW_IE_EXCEPTION << layer->name << " 'Axis' tensor must be 1D array of 1 element";
if (!(inShapes[AXIS].size() == 0 ||
(inShapes[AXIS].size() == 1 && inShapes[AXIS][0] == 1)))
THROW_IE_EXCEPTION << layer->name << " 'Axis' tensor must be scalar, or 1D array of 1 element";
if (inShapes[UPDATES].size() != inShapes[INDICES].size() + inShapes[DATA].size() - 1)
THROW_IE_EXCEPTION << layer->name << " Incorrect number of 'indexes' and 'updates' tensors dimension";
Precision inIdxPrecision = layer->insData[INDICES].lock()->getTensorDesc().getPrecision();
if (inIdxPrecision != Precision::FP32 && inIdxPrecision != Precision::I32)
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Indices' precision. Only FP32 or I32 are supported!";
if (inIdxPrecision != Precision::I32 && inIdxPrecision != Precision::I64)
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Indices' precision. Only I32 or I64 are supported!";
Precision inAxisPrecision = layer->insData[AXIS].lock()->getTensorDesc().getPrecision();
if (inAxisPrecision != Precision::FP32 && inAxisPrecision != Precision::I32)
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Axis' precision. Only FP32 or I32 are supported!";
if (inAxisPrecision != Precision::I32 && inAxisPrecision != Precision::I64)
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Axis' precision. Only I32 or I64 are supported!";
if (layer->insData[DATA].lock()->getTensorDesc().getPrecision() !=
layer->insData[UPDATES].lock()->getTensorDesc().getPrecision())
@ -3157,8 +3158,9 @@ void ScatterElementsUpdateValidator::checkShapes(const CNNLayer* layer, const ve
if (inShapes[UPDATES].size() < 1)
THROW_IE_EXCEPTION << layer->name << " 'Updates' tensor rank must be >= 1";
if (!(inShapes[AXIS].size() == 1 && inShapes[AXIS][0] == 1))
THROW_IE_EXCEPTION << layer->name << " 'Axis' tensor must be 1D array of 1 element";
if (!(inShapes[AXIS].size() == 0 ||
(inShapes[AXIS].size() == 1 && inShapes[AXIS][0] == 1)))
THROW_IE_EXCEPTION << layer->name << " 'Axis' tensor must be scalar, or 1D array of 1 element";
if (inShapes[INDICES].size() != inShapes[DATA].size())
THROW_IE_EXCEPTION << layer->name << " Incorrect number of 'indexes' tensors dimension";
@ -3167,12 +3169,12 @@ void ScatterElementsUpdateValidator::checkShapes(const CNNLayer* layer, const ve
THROW_IE_EXCEPTION << layer->name << " Incorrect number of 'updates' tensors dimension";
Precision inIdxPrecision = layer->insData[INDICES].lock()->getTensorDesc().getPrecision();
if (inIdxPrecision != Precision::FP32 && inIdxPrecision != Precision::I32 && inIdxPrecision != Precision::I64)
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Indices' precision. Only FP32 or I32 or I64 are supported!";
if (inIdxPrecision != Precision::I32 && inIdxPrecision != Precision::I64)
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Indices' precision. Only I32 or I64 are supported!";
Precision inAxisPrecision = layer->insData[AXIS].lock()->getTensorDesc().getPrecision();
if (inAxisPrecision != Precision::FP32 && inAxisPrecision != Precision::I32 && inIdxPrecision != Precision::I64)
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Axis' precision. Only FP32 or I32 or I64 are supported!";
if (inAxisPrecision != Precision::I32 && inIdxPrecision != Precision::I64)
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Axis' precision. Only I32 or I64 are supported!";
if (layer->insData[DATA].lock()->getTensorDesc().getPrecision() !=
layer->insData[UPDATES].lock()->getTensorDesc().getPrecision())

View File

@ -331,7 +331,6 @@ private:
<layer id="3" name="axis" type="Input">
<output>
<port id="0" precision="I32">
<dim>1</dim>
</port>
</output>
</layer>
@ -347,7 +346,6 @@ private:
__UPDATES_DIMS__
</port>
<port id="3" precision="I32">
<dim>1</dim>
</port>
</input>
<output>

View File

@ -76,8 +76,8 @@ protected:
SizeVector outputShape = inputShape; // copy
const int outputNDims = inputNDims;
SizeVector axisShape = { 1 };
const int axisNDims = 1;
SizeVector axisShape = {};
const int axisNDims = 0;
// E.g.:
// {N, C, H, W} could be shape of `input` and `output`
@ -251,8 +251,11 @@ private:
const std::vector<ie_fp16>& updatesData,
const std::vector<int32_t>& axisData) {
// yet we only support axis == 0
IE_ASSERT(axisShape.size() == 1);
IE_ASSERT(axisShape[0] == 1);
IE_ASSERT(axisShape.size() == 0 ||
axisShape.size() == 1);
if (axisShape.size() > 0) {
IE_ASSERT(axisShape[0] == 1);
}
IE_ASSERT(axisData[0] == 0);
// copy `input` to `output`
@ -400,7 +403,6 @@ private:
<layer id="3" name="axis" type="Input">
<output>
<port id="0" precision="I32">
<dim>1</dim>
</port>
</output>
</layer>
@ -416,7 +418,6 @@ private:
__UPDATES_DIMS__
</port>
<port id="3" precision="I32">
<dim>1</dim>
</port>
</input>
<output>