fix input issuse of ScatterNDUpdate conformance test (#16406)
* fix input issuse of ScatterNDUpdate conformance test Signed-off-by: Hu Yuan2 <yuan2.hu@intel.com> * fix typo and optimize temporary variable Signed-off-by: Hu Yuan2 <yuan2.hu@intel.com> --------- Signed-off-by: Hu Yuan2 <yuan2.hu@intel.com>
This commit is contained in:
parent
ca1102b855
commit
2255bb25fd
@ -827,6 +827,91 @@ ov::runtime::Tensor generate(const std::shared_ptr<ngraph::op::v8::Softmax>& nod
|
||||
return generate(std::dynamic_pointer_cast<ov::Node>(node), port, elemType, targetShape);
|
||||
}
|
||||
|
||||
ov::runtime::Tensor generate(const
|
||||
std::shared_ptr<ngraph::op::v3::ScatterNDUpdate>& node,
|
||||
size_t port,
|
||||
const ov::element::Type& elemType,
|
||||
const ov::Shape& targetShape) {
|
||||
// when fill indices
|
||||
if (port == 1) {
|
||||
auto srcShape = node->get_input_shape(0);
|
||||
// the data in indices must be unique.
|
||||
// so need to select part data from total collection
|
||||
// Calculate the collection size
|
||||
int k = targetShape[targetShape.size() - 1];
|
||||
int totalSize = 1;
|
||||
for (int i = 0; i < k; i++) {
|
||||
totalSize *= srcShape[i];
|
||||
}
|
||||
size_t indShapeSize = ngraph::shape_size(targetShape);
|
||||
// Calculate the size of part data
|
||||
int selectNums = indShapeSize / k;
|
||||
// create total collection
|
||||
std::vector<int> collection(totalSize);
|
||||
for (int i = 0; i < totalSize; i++) {
|
||||
collection[i] = i;
|
||||
}
|
||||
// select part data from collection
|
||||
// the last selectNums data in collection are what want to be filled into tensor
|
||||
testing::internal::Random random(1);
|
||||
int r = 0;
|
||||
int tmp = 0;
|
||||
for (int i = 0, y = totalSize; i < selectNums; i++, y--) {
|
||||
r = random.Generate(y);
|
||||
// switch y and r
|
||||
tmp = collection[y - 1];
|
||||
collection[y - 1] = collection[r];
|
||||
collection[r] = tmp;
|
||||
}
|
||||
// if the shape of source data is (a ,b ,c)
|
||||
// the strides is (bc, c, 1)
|
||||
std::vector<int> strides;
|
||||
int stride = 1;
|
||||
strides.push_back(stride);
|
||||
for (int i = k - 1; i > 0; i--) {
|
||||
stride *= srcShape[i];
|
||||
strides.push_back(stride);
|
||||
}
|
||||
std::reverse(strides.begin(), strides.end());
|
||||
// create tensor and fill function
|
||||
auto tensor = ov::Tensor{elemType, targetShape};
|
||||
auto fill_data = [&elemType, &tensor](int offset, int value) {
|
||||
switch (elemType) {
|
||||
case ov::element::Type_t::i32: {
|
||||
auto data =
|
||||
tensor.data<element_type_traits<ov::element::Type_t::i32>::value_type>();
|
||||
data[offset] = value;
|
||||
break;
|
||||
}
|
||||
case ov::element::Type_t::i64: {
|
||||
auto data =
|
||||
tensor.data<element_type_traits<ov::element::Type_t::i64>::value_type>();
|
||||
data[offset] = value;
|
||||
break;
|
||||
}
|
||||
default:
|
||||
throw std::runtime_error("indices type should be int32 or int64");
|
||||
}
|
||||
};
|
||||
// start to fill data
|
||||
int index = 0;
|
||||
int tmpNum = 0;
|
||||
for (int i = totalSize - selectNums, y = 0; i < totalSize; i++, y = y + k) {
|
||||
tmpNum = collection[i];
|
||||
for (int z = 0; z < k; z++) {
|
||||
//Calculate index of dims
|
||||
index = tmpNum / strides[z];
|
||||
tmpNum = tmpNum % strides[z];
|
||||
fill_data(y + z, index);
|
||||
}
|
||||
}
|
||||
return tensor;
|
||||
} else {
|
||||
return generate(std::dynamic_pointer_cast<ov::Node>(node), port, elemType,
|
||||
targetShape);
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
ov::runtime::Tensor generateInput(const std::shared_ptr<ov::Node>& node,
|
||||
size_t port,
|
||||
|
Loading…
Reference in New Issue
Block a user