diff --git a/ngraph/core/reference/include/ngraph/runtime/reference/transpose.hpp b/ngraph/core/reference/include/ngraph/runtime/reference/transpose.hpp index fc395791eca..f87690bc989 100644 --- a/ngraph/core/reference/include/ngraph/runtime/reference/transpose.hpp +++ b/ngraph/core/reference/include/ngraph/runtime/reference/transpose.hpp @@ -23,7 +23,7 @@ namespace ngraph template void transpose(const T* arg, T* out, Shape arg_size, const U* axes_order = nullptr) { - std::vector range_vector; + std::vector range_vector; if (axes_order == nullptr) { range_vector.resize(arg_size.size()); @@ -31,19 +31,29 @@ namespace ngraph std::reverse(range_vector.begin(), range_vector.end()); axes_order = range_vector.data(); } - size_t cnt = 0; - for (size_t i = 0; i < arg_size.size(); ++i) + + std::vector input_strides(arg_size.size()); + std::vector output_strides(arg_size.size()); + input_strides.back() = 1; + output_strides.back() = 1; + + for (int i = input_strides.size() - 2; i >= 0; i--) { - size_t axes = axes_order[i]; - size_t start = 0; - for (size_t j = 0; j < axes; ++j) + input_strides[i] = input_strides[i + 1] * arg_size[i + 1]; + output_strides[i] = output_strides[i + 1] * arg_size[axes_order[i + 1]]; + } + for (int i = 0; i < shape_size(arg_size); ++i) + { + size_t in_position = 0; + size_t new_position = i; + + for (int j = 0; j < arg_size.size(); ++j) { - start += shape_size(arg_size[j]); - } - for (size_t j = start; j < start + shape_size(arg_size[axes]); ++j) - { - out[cnt++] = arg[j]; + in_position += + (new_position / output_strides[j]) * input_strides[axes_order[j]]; + new_position %= output_strides[j]; } + out[i] = arg[in_position]; } } }