Fixed transpose reference (#4688)

This commit is contained in:
Liubov Batanina 2021-03-26 16:27:47 +03:00 committed by GitHub
parent 9bd63176f8
commit 69b76f188c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -23,7 +23,7 @@ namespace ngraph
template <typename T, typename U>
void transpose(const T* arg, T* out, Shape arg_size, const U* axes_order = nullptr)
{
std::vector<size_t> range_vector;
std::vector<U> range_vector;
if (axes_order == nullptr)
{
range_vector.resize(arg_size.size());
@ -31,19 +31,29 @@ namespace ngraph
std::reverse(range_vector.begin(), range_vector.end());
axes_order = range_vector.data();
}
size_t cnt = 0;
for (size_t i = 0; i < arg_size.size(); ++i)
std::vector<size_t> input_strides(arg_size.size());
std::vector<size_t> output_strides(arg_size.size());
input_strides.back() = 1;
output_strides.back() = 1;
for (int i = input_strides.size() - 2; i >= 0; i--)
{
size_t axes = axes_order[i];
size_t start = 0;
for (size_t j = 0; j < axes; ++j)
input_strides[i] = input_strides[i + 1] * arg_size[i + 1];
output_strides[i] = output_strides[i + 1] * arg_size[axes_order[i + 1]];
}
for (int i = 0; i < shape_size(arg_size); ++i)
{
size_t in_position = 0;
size_t new_position = i;
for (int j = 0; j < arg_size.size(); ++j)
{
start += shape_size(arg_size[j]);
}
for (size_t j = start; j < start + shape_size(arg_size[axes]); ++j)
{
out[cnt++] = arg[j];
in_position +=
(new_position / output_strides[j]) * input_strides[axes_order[j]];
new_position %= output_strides[j];
}
out[i] = arg[in_position];
}
}
}