Optimization of ScatterElementsUpdate ref impl (#18313)
Co-authored-by: Katarzyna Mitrus <katarzyna.mitrus@intel.com> Co-authored-by: Michal Lukaszewski <michal.lukaszewski@intel.com>
This commit is contained in:
@@ -149,10 +149,7 @@ typename std::enable_if<std::is_floating_point<T>::value || std::is_class<T>::va
|
||||
|
||||
template <typename T>
|
||||
typename std::enable_if<std::is_integral<T>::value, T>::type arithmetic_mean(const T accumulator, const int32_t N) {
|
||||
const auto old_mode = std::fegetround();
|
||||
std::fesetround(FE_DOWNWARD);
|
||||
const T value = static_cast<T>(std::nearbyint(static_cast<double>(accumulator) / N));
|
||||
std::fesetround(old_mode);
|
||||
return value;
|
||||
}
|
||||
|
||||
@@ -165,6 +162,25 @@ size_t normalize_index(const T idx, const size_t dim_value) {
|
||||
}
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
struct RoundingDirectionGuard {
|
||||
RoundingDirectionGuard() {
|
||||
if (std::is_integral<T>::value) {
|
||||
m_original_mode = std::fegetround();
|
||||
std::fesetround(FE_DOWNWARD);
|
||||
}
|
||||
}
|
||||
|
||||
~RoundingDirectionGuard() {
|
||||
if (std::is_integral<T>::value) {
|
||||
std::fesetround(m_original_mode);
|
||||
}
|
||||
}
|
||||
|
||||
private:
|
||||
decltype(std::fegetround()) m_original_mode;
|
||||
};
|
||||
|
||||
template <typename DataType, typename IndicesType>
|
||||
void scatter_elem_update_with_reduction(const DataType* input_data,
|
||||
const IndicesType* indices,
|
||||
@@ -221,6 +237,9 @@ void scatter_elem_update_with_reduction(const DataType* input_data,
|
||||
}
|
||||
|
||||
if (reduction_type == Reduction::MEAN) {
|
||||
// this object will change the rounding mode only for integer types which is required to match torch
|
||||
// upon destruction the previously used rounding mode will be restored
|
||||
RoundingDirectionGuard<DataType> rounding_guard;
|
||||
for (const auto& counter : mean_reduction_counters) {
|
||||
// include the initial value in the arithmetic mean divisor (if needed)
|
||||
const auto N = counter.second + static_cast<int32_t>(use_init_val);
|
||||
|
||||
Reference in New Issue
Block a user