Fixed transpose reference (#4688)
This commit is contained in:
parent
9bd63176f8
commit
69b76f188c
@ -23,7 +23,7 @@ namespace ngraph
|
||||
template <typename T, typename U>
|
||||
void transpose(const T* arg, T* out, Shape arg_size, const U* axes_order = nullptr)
|
||||
{
|
||||
std::vector<size_t> range_vector;
|
||||
std::vector<U> range_vector;
|
||||
if (axes_order == nullptr)
|
||||
{
|
||||
range_vector.resize(arg_size.size());
|
||||
@ -31,19 +31,29 @@ namespace ngraph
|
||||
std::reverse(range_vector.begin(), range_vector.end());
|
||||
axes_order = range_vector.data();
|
||||
}
|
||||
size_t cnt = 0;
|
||||
for (size_t i = 0; i < arg_size.size(); ++i)
|
||||
|
||||
std::vector<size_t> input_strides(arg_size.size());
|
||||
std::vector<size_t> output_strides(arg_size.size());
|
||||
input_strides.back() = 1;
|
||||
output_strides.back() = 1;
|
||||
|
||||
for (int i = input_strides.size() - 2; i >= 0; i--)
|
||||
{
|
||||
size_t axes = axes_order[i];
|
||||
size_t start = 0;
|
||||
for (size_t j = 0; j < axes; ++j)
|
||||
input_strides[i] = input_strides[i + 1] * arg_size[i + 1];
|
||||
output_strides[i] = output_strides[i + 1] * arg_size[axes_order[i + 1]];
|
||||
}
|
||||
for (int i = 0; i < shape_size(arg_size); ++i)
|
||||
{
|
||||
size_t in_position = 0;
|
||||
size_t new_position = i;
|
||||
|
||||
for (int j = 0; j < arg_size.size(); ++j)
|
||||
{
|
||||
start += shape_size(arg_size[j]);
|
||||
}
|
||||
for (size_t j = start; j < start + shape_size(arg_size[axes]); ++j)
|
||||
{
|
||||
out[cnt++] = arg[j];
|
||||
in_position +=
|
||||
(new_position / output_strides[j]) * input_strides[axes_order[j]];
|
||||
new_position %= output_strides[j];
|
||||
}
|
||||
out[i] = arg[in_position];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user