fix input issuse of ScatterNDUpdate conformance test (#16406)

* fix input issuse of ScatterNDUpdate conformance test

Signed-off-by: Hu Yuan2 <yuan2.hu@intel.com>

* fix typo and optimize temporary variable

Signed-off-by: Hu Yuan2 <yuan2.hu@intel.com>

---------

Signed-off-by: Hu Yuan2 <yuan2.hu@intel.com>
This commit is contained in:
Yuan Hu 2023-04-25 17:00:22 +08:00 committed by GitHub
parent ca1102b855
commit 2255bb25fd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -827,6 +827,91 @@ ov::runtime::Tensor generate(const std::shared_ptr<ngraph::op::v8::Softmax>& nod
return generate(std::dynamic_pointer_cast<ov::Node>(node), port, elemType, targetShape);
}
ov::runtime::Tensor generate(const
std::shared_ptr<ngraph::op::v3::ScatterNDUpdate>& node,
size_t port,
const ov::element::Type& elemType,
const ov::Shape& targetShape) {
// when fill indices
if (port == 1) {
auto srcShape = node->get_input_shape(0);
// the data in indices must be unique.
// so need to select part data from total collection
// Calculate the collection size
int k = targetShape[targetShape.size() - 1];
int totalSize = 1;
for (int i = 0; i < k; i++) {
totalSize *= srcShape[i];
}
size_t indShapeSize = ngraph::shape_size(targetShape);
// Calculate the size of part data
int selectNums = indShapeSize / k;
// create total collection
std::vector<int> collection(totalSize);
for (int i = 0; i < totalSize; i++) {
collection[i] = i;
}
// select part data from collection
// the last selectNums data in collection are what want to be filled into tensor
testing::internal::Random random(1);
int r = 0;
int tmp = 0;
for (int i = 0, y = totalSize; i < selectNums; i++, y--) {
r = random.Generate(y);
// switch y and r
tmp = collection[y - 1];
collection[y - 1] = collection[r];
collection[r] = tmp;
}
// if the shape of source data is (a ,b ,c)
// the strides is (bc, c, 1)
std::vector<int> strides;
int stride = 1;
strides.push_back(stride);
for (int i = k - 1; i > 0; i--) {
stride *= srcShape[i];
strides.push_back(stride);
}
std::reverse(strides.begin(), strides.end());
// create tensor and fill function
auto tensor = ov::Tensor{elemType, targetShape};
auto fill_data = [&elemType, &tensor](int offset, int value) {
switch (elemType) {
case ov::element::Type_t::i32: {
auto data =
tensor.data<element_type_traits<ov::element::Type_t::i32>::value_type>();
data[offset] = value;
break;
}
case ov::element::Type_t::i64: {
auto data =
tensor.data<element_type_traits<ov::element::Type_t::i64>::value_type>();
data[offset] = value;
break;
}
default:
throw std::runtime_error("indices type should be int32 or int64");
}
};
// start to fill data
int index = 0;
int tmpNum = 0;
for (int i = totalSize - selectNums, y = 0; i < totalSize; i++, y = y + k) {
tmpNum = collection[i];
for (int z = 0; z < k; z++) {
//Calculate index of dims
index = tmpNum / strides[z];
tmpNum = tmpNum % strides[z];
fill_data(y + z, index);
}
}
return tensor;
} else {
return generate(std::dynamic_pointer_cast<ov::Node>(node), port, elemType,
targetShape);
}
}
template<typename T>
ov::runtime::Tensor generateInput(const std::shared_ptr<ov::Node>& node,
size_t port,