Optimization of ScatterElementsUpdate ref impl (#18313)

Co-authored-by: Katarzyna Mitrus <katarzyna.mitrus@intel.com>
Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
This commit is contained in:
Tomasz Dołbniak
2023-07-10 15:58:37 +02:00
committed by GitHub
parent 58de48a491
commit 975ba2a92b

View File

@@ -149,10 +149,7 @@ typename std::enable_if<std::is_floating_point<T>::value || std::is_class<T>::va
template <typename T>
typename std::enable_if<std::is_integral<T>::value, T>::type arithmetic_mean(const T accumulator, const int32_t N) {
const auto old_mode = std::fegetround();
std::fesetround(FE_DOWNWARD);
const T value = static_cast<T>(std::nearbyint(static_cast<double>(accumulator) / N));
std::fesetround(old_mode);
return value;
}
@@ -165,6 +162,25 @@ size_t normalize_index(const T idx, const size_t dim_value) {
}
}
template <typename T>
struct RoundingDirectionGuard {
RoundingDirectionGuard() {
if (std::is_integral<T>::value) {
m_original_mode = std::fegetround();
std::fesetround(FE_DOWNWARD);
}
}
~RoundingDirectionGuard() {
if (std::is_integral<T>::value) {
std::fesetround(m_original_mode);
}
}
private:
decltype(std::fegetround()) m_original_mode;
};
template <typename DataType, typename IndicesType>
void scatter_elem_update_with_reduction(const DataType* input_data,
const IndicesType* indices,
@@ -221,6 +237,9 @@ void scatter_elem_update_with_reduction(const DataType* input_data,
}
if (reduction_type == Reduction::MEAN) {
// this object will change the rounding mode only for integer types which is required to match torch
// upon destruction the previously used rounding mode will be restored
RoundingDirectionGuard<DataType> rounding_guard;
for (const auto& counter : mean_reduction_counters) {
// include the initial value in the arithmetic mean divisor (if needed)
const auto N = counter.second + static_cast<int32_t>(use_init_val);