ICV: fix Scatter layers: fix validators (#541)
* ICV: fix Scatter layers: fix validators * ICV: fix Scatter layers: enable 0D for `axis` * Revert "ICV: fix Scatter layers: enable 0D for `axis`" This reverts commit 82da24b989678061a585a5c7ffd7d5dab10f5edc. * ICV: fix Scatter layers: test, fix CNNNetworkImpl
This commit is contained in:
parent
946ed119c8
commit
d24132912e
@ -325,6 +325,9 @@ size_t CNNNetworkImpl::getBatchSize() const noexcept {
|
||||
if (dims.size() == 3 || dims.size() == 1) {
|
||||
return 1;
|
||||
}
|
||||
if (dims.size() == 0) {
|
||||
return 0;
|
||||
}
|
||||
return dims.at(0);
|
||||
}
|
||||
|
||||
|
@ -3105,19 +3105,20 @@ void ScatterUpdateValidator::checkShapes(const CNNLayer* layer, const vector<Siz
|
||||
if (inShapes[UPDATES].size() < 1)
|
||||
THROW_IE_EXCEPTION << layer->name << " 'Updates' tensor rank must be >= 1";
|
||||
|
||||
if (!(inShapes[AXIS].size() == 1 && inShapes[AXIS][0] == 1))
|
||||
THROW_IE_EXCEPTION << layer->name << " 'Axis' tensor must be 1D array of 1 element";
|
||||
if (!(inShapes[AXIS].size() == 0 ||
|
||||
(inShapes[AXIS].size() == 1 && inShapes[AXIS][0] == 1)))
|
||||
THROW_IE_EXCEPTION << layer->name << " 'Axis' tensor must be scalar, or 1D array of 1 element";
|
||||
|
||||
if (inShapes[UPDATES].size() != inShapes[INDICES].size() + inShapes[DATA].size() - 1)
|
||||
THROW_IE_EXCEPTION << layer->name << " Incorrect number of 'indexes' and 'updates' tensors dimension";
|
||||
|
||||
Precision inIdxPrecision = layer->insData[INDICES].lock()->getTensorDesc().getPrecision();
|
||||
if (inIdxPrecision != Precision::FP32 && inIdxPrecision != Precision::I32)
|
||||
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Indices' precision. Only FP32 or I32 are supported!";
|
||||
if (inIdxPrecision != Precision::I32 && inIdxPrecision != Precision::I64)
|
||||
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Indices' precision. Only I32 or I64 are supported!";
|
||||
|
||||
Precision inAxisPrecision = layer->insData[AXIS].lock()->getTensorDesc().getPrecision();
|
||||
if (inAxisPrecision != Precision::FP32 && inAxisPrecision != Precision::I32)
|
||||
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Axis' precision. Only FP32 or I32 are supported!";
|
||||
if (inAxisPrecision != Precision::I32 && inAxisPrecision != Precision::I64)
|
||||
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Axis' precision. Only I32 or I64 are supported!";
|
||||
|
||||
if (layer->insData[DATA].lock()->getTensorDesc().getPrecision() !=
|
||||
layer->insData[UPDATES].lock()->getTensorDesc().getPrecision())
|
||||
@ -3157,8 +3158,9 @@ void ScatterElementsUpdateValidator::checkShapes(const CNNLayer* layer, const ve
|
||||
if (inShapes[UPDATES].size() < 1)
|
||||
THROW_IE_EXCEPTION << layer->name << " 'Updates' tensor rank must be >= 1";
|
||||
|
||||
if (!(inShapes[AXIS].size() == 1 && inShapes[AXIS][0] == 1))
|
||||
THROW_IE_EXCEPTION << layer->name << " 'Axis' tensor must be 1D array of 1 element";
|
||||
if (!(inShapes[AXIS].size() == 0 ||
|
||||
(inShapes[AXIS].size() == 1 && inShapes[AXIS][0] == 1)))
|
||||
THROW_IE_EXCEPTION << layer->name << " 'Axis' tensor must be scalar, or 1D array of 1 element";
|
||||
|
||||
if (inShapes[INDICES].size() != inShapes[DATA].size())
|
||||
THROW_IE_EXCEPTION << layer->name << " Incorrect number of 'indexes' tensors dimension";
|
||||
@ -3167,12 +3169,12 @@ void ScatterElementsUpdateValidator::checkShapes(const CNNLayer* layer, const ve
|
||||
THROW_IE_EXCEPTION << layer->name << " Incorrect number of 'updates' tensors dimension";
|
||||
|
||||
Precision inIdxPrecision = layer->insData[INDICES].lock()->getTensorDesc().getPrecision();
|
||||
if (inIdxPrecision != Precision::FP32 && inIdxPrecision != Precision::I32 && inIdxPrecision != Precision::I64)
|
||||
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Indices' precision. Only FP32 or I32 or I64 are supported!";
|
||||
if (inIdxPrecision != Precision::I32 && inIdxPrecision != Precision::I64)
|
||||
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Indices' precision. Only I32 or I64 are supported!";
|
||||
|
||||
Precision inAxisPrecision = layer->insData[AXIS].lock()->getTensorDesc().getPrecision();
|
||||
if (inAxisPrecision != Precision::FP32 && inAxisPrecision != Precision::I32 && inIdxPrecision != Precision::I64)
|
||||
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Axis' precision. Only FP32 or I32 or I64 are supported!";
|
||||
if (inAxisPrecision != Precision::I32 && inIdxPrecision != Precision::I64)
|
||||
THROW_IE_EXCEPTION << layer->name << " Incorrect input 'Axis' precision. Only I32 or I64 are supported!";
|
||||
|
||||
if (layer->insData[DATA].lock()->getTensorDesc().getPrecision() !=
|
||||
layer->insData[UPDATES].lock()->getTensorDesc().getPrecision())
|
||||
|
@ -331,7 +331,6 @@ private:
|
||||
<layer id="3" name="axis" type="Input">
|
||||
<output>
|
||||
<port id="0" precision="I32">
|
||||
<dim>1</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
@ -347,7 +346,6 @@ private:
|
||||
__UPDATES_DIMS__
|
||||
</port>
|
||||
<port id="3" precision="I32">
|
||||
<dim>1</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
|
@ -76,8 +76,8 @@ protected:
|
||||
SizeVector outputShape = inputShape; // copy
|
||||
const int outputNDims = inputNDims;
|
||||
|
||||
SizeVector axisShape = { 1 };
|
||||
const int axisNDims = 1;
|
||||
SizeVector axisShape = {};
|
||||
const int axisNDims = 0;
|
||||
|
||||
// E.g.:
|
||||
// {N, C, H, W} could be shape of `input` and `output`
|
||||
@ -251,8 +251,11 @@ private:
|
||||
const std::vector<ie_fp16>& updatesData,
|
||||
const std::vector<int32_t>& axisData) {
|
||||
// yet we only support axis == 0
|
||||
IE_ASSERT(axisShape.size() == 1);
|
||||
IE_ASSERT(axisShape[0] == 1);
|
||||
IE_ASSERT(axisShape.size() == 0 ||
|
||||
axisShape.size() == 1);
|
||||
if (axisShape.size() > 0) {
|
||||
IE_ASSERT(axisShape[0] == 1);
|
||||
}
|
||||
IE_ASSERT(axisData[0] == 0);
|
||||
|
||||
// copy `input` to `output`
|
||||
@ -400,7 +403,6 @@ private:
|
||||
<layer id="3" name="axis" type="Input">
|
||||
<output>
|
||||
<port id="0" precision="I32">
|
||||
<dim>1</dim>
|
||||
</port>
|
||||
</output>
|
||||
</layer>
|
||||
@ -416,7 +418,6 @@ private:
|
||||
__UPDATES_DIMS__
|
||||
</port>
|
||||
<port id="3" precision="I32">
|
||||
<dim>1</dim>
|
||||
</port>
|
||||
</input>
|
||||
<output>
|
||||
|
Loading…
Reference in New Issue
Block a user