Fix fill_tensor_random for real values: (#19095)
This commit is contained in:
parent
545c5bbde1
commit
5379068610
@ -266,8 +266,7 @@ void inline fill_random_unique_sequence(T* rawBlobDataPtr,
|
|||||||
* - With k = 2 numbers resolution will 1/2 so outputs only .0 or .50
|
* - With k = 2 numbers resolution will 1/2 so outputs only .0 or .50
|
||||||
* - With k = 4 numbers resolution will 1/4 so outputs only .0 .25 .50 0.75 and etc.
|
* - With k = 4 numbers resolution will 1/4 so outputs only .0 .25 .50 0.75 and etc.
|
||||||
*/
|
*/
|
||||||
void fill_tensor_random(ov::Tensor& tensor, const uint32_t range = 10, int32_t start_from = 0,
|
void fill_tensor_random(ov::Tensor& tensor, const double range = 10, const double start_from = 0, const int32_t k = 1, const int seed = 1);
|
||||||
const int32_t k = 1, const int seed = 1);
|
|
||||||
|
|
||||||
/** @brief Fill blob with random data.
|
/** @brief Fill blob with random data.
|
||||||
*
|
*
|
||||||
|
@ -357,8 +357,7 @@ void fill_data_with_broadcast(ov::Tensor& tensor, size_t axis, std::vector<float
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<ov::element::Type_t DT>
|
template<ov::element::Type_t DT>
|
||||||
void fill_tensor_random(ov::Tensor& tensor, const uint32_t range = 10, int32_t start_from = 0,
|
void fill_tensor_random(ov::Tensor& tensor, const uint32_t range, const int32_t start_from, const int32_t k, const int seed) {
|
||||||
const int32_t k = 1, const int seed = 1) {
|
|
||||||
using T = typename ov::element_type_traits<DT>::value_type;
|
using T = typename ov::element_type_traits<DT>::value_type;
|
||||||
auto *rawBlobDataPtr = static_cast<T*>(tensor.data());
|
auto *rawBlobDataPtr = static_cast<T*>(tensor.data());
|
||||||
if (DT == ov::element::u4 || DT == ov::element::i4 ||
|
if (DT == ov::element::u4 || DT == ov::element::i4 ||
|
||||||
@ -370,12 +369,11 @@ void fill_tensor_random(ov::Tensor& tensor, const uint32_t range = 10, int32_t s
|
|||||||
}
|
}
|
||||||
|
|
||||||
template<ov::element::Type_t DT>
|
template<ov::element::Type_t DT>
|
||||||
void fill_tensor_random_float(ov::Tensor& tensor, const uint32_t range, int32_t start_from, const int32_t k,
|
void fill_tensor_random_float(ov::Tensor& tensor, const double range, const double start_from, const int32_t k, const int seed) {
|
||||||
const int seed = 1) {
|
|
||||||
using T = typename ov::element_type_traits<DT>::value_type;
|
using T = typename ov::element_type_traits<DT>::value_type;
|
||||||
std::default_random_engine random(seed);
|
std::default_random_engine random(seed);
|
||||||
// 1/k is the resolution of the floating point numbers
|
// 1/k is the resolution of the floating point numbers
|
||||||
std::uniform_int_distribution<int32_t> distribution(k * start_from, k * (start_from + range));
|
std::uniform_real_distribution<double> distribution(k * start_from, k * (start_from + range));
|
||||||
|
|
||||||
auto *rawBlobDataPtr = static_cast<T*>(tensor.data());
|
auto *rawBlobDataPtr = static_cast<T*>(tensor.data());
|
||||||
for (size_t i = 0; i < tensor.get_size(); i++) {
|
for (size_t i = 0; i < tensor.get_size(); i++) {
|
||||||
@ -391,11 +389,10 @@ void fill_tensor_random_float(ov::Tensor& tensor, const uint32_t range, int32_t
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
void fill_tensor_random(ov::Tensor& tensor, const uint32_t range, int32_t start_from,
|
void fill_tensor_random(ov::Tensor& tensor, const double range, const double start_from, const int32_t k, const int seed) {
|
||||||
const int32_t k, const int seed) {
|
|
||||||
auto element_type = tensor.get_element_type();
|
auto element_type = tensor.get_element_type();
|
||||||
|
|
||||||
#define CASE(X) case X: fill_tensor_random<X>(tensor, range, start_from, k, seed); break;
|
#define CASE(X) case X: fill_tensor_random<X>(tensor, static_cast<uint32_t>(range), static_cast<int32_t>(start_from), k, seed); break;
|
||||||
#define CASE_FLOAT(X) case X: fill_tensor_random_float<X>(tensor, range, start_from, k, seed); break;
|
#define CASE_FLOAT(X) case X: fill_tensor_random_float<X>(tensor, range, start_from, k, seed); break;
|
||||||
|
|
||||||
switch (element_type) {
|
switch (element_type) {
|
||||||
|
Loading…
Reference in New Issue
Block a user