Fixed Gather indexes

This commit is contained in:
Mikhail Ryzhov 2023-03-22 14:55:10 +01:00
parent 02abf9b1f0
commit d105cfcc68
3 changed files with 62 additions and 57 deletions

View File

@ -11,6 +11,7 @@
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/gather_sinking_utils.hpp"
#include "transformations/utils/transformation_helper.hpp"
#include "log/debug.hpp"
using namespace ov::intel_gna::pass;
using namespace ov;
@ -23,20 +24,6 @@ namespace {
using NodePtr = std::shared_ptr<ov::Node>;
using NodePair = std::pair<NodePtr, NodePtr>;
/*
* Finds the end of the slice transpose order. Slice here is the subvector
* with increasing values transpose_order[i + 1] == transpose_order[i] + 1
*/
size_t FindEndOfSlice(const Shape& transpose_order, size_t start_idx) {
size_t slice_end = start_idx;
for (size_t i = start_idx + 1; i < transpose_order.size(); ++i) {
if (transpose_order[i] != transpose_order[slice_end] + 1)
break;
slice_end = i;
}
return slice_end;
}
size_t GetSliceNum(const Shape& transpose_order) {
size_t slice_count = 0;
for (size_t i = 0; i < transpose_order.size(); ++i) {
@ -56,44 +43,47 @@ inline size_t GetFirstValuableDimId(const ov::Shape& shape) {
return 0;
}
std::vector<size_t> CreateGatherIndices(const Shape& transpose_input_shape,
const Shape& reshape_output_shape,
const Shape& transpose_order) {
const size_t slice_0_end = FindEndOfSlice(transpose_order, 0);
const size_t slice_1_start = slice_0_end + 1;
const size_t slice_1_end = FindEndOfSlice(transpose_order, slice_1_start);
const size_t slice_2_start = slice_1_end + 1;
if (slice_0_end >= transpose_input_shape.size() || slice_1_start >= transpose_input_shape.size() ||
slice_1_end >= transpose_input_shape.size() || slice_2_start >= transpose_input_shape.size()) {
return {};
std::vector<size_t> CreateGatherIndices(const Shape& input_shape,
const Shape& order) {
if (input_shape.size() < 2 || input_shape.size() > 4) {
THROW_GNA_EXCEPTION << "Usupported shape size: " << input_shape.size();
}
const int64_t transpose_part_0 = std::accumulate(transpose_order.begin() + slice_1_start,
transpose_order.begin() + slice_1_end + 1,
1,
[&transpose_input_shape](int64_t result, int64_t order_value) {
return result *= transpose_input_shape[order_value];
});
const int64_t transpose_part_1 = std::accumulate(transpose_order.begin() + slice_2_start,
transpose_order.end(),
1,
[&transpose_input_shape](int64_t result, int64_t order_value) {
return result *= transpose_input_shape[order_value];
});
ov::Shape input_shape_4d = input_shape;
ov::Shape order_4d = order;
// Just to simplify the code we transform all shapes to 4d by adding 1 dimentions at the end
while (input_shape_4d.size() < 4) {
input_shape_4d.push_back(1);
order_4d.push_back(order_4d.size());
}
ov::Shape output_shape_4d = helper::TransposeShape(input_shape_4d, order_4d);
std::vector<size_t> gather_indices_value(helper::SqueezeShape(reshape_output_shape).back());
for (size_t i = 0; i < gather_indices_value.size(); ++i) {
gather_indices_value[i] = transpose_part_0 * (i % transpose_part_1) + i / transpose_part_1;
// common case when shape is 4d
std::vector<size_t> xyz_4d = { input_shape_4d[3] * input_shape_4d[2] * input_shape_4d[1],
input_shape_4d[3] * input_shape_4d[2],
input_shape_4d[3],
1 };
std::vector<size_t> xyz = helper::TransposeShape(xyz_4d, order);
std::vector<size_t> gather_order;
for (size_t n = 0; n < output_shape_4d[0]; ++n) {
for (size_t i = 0; i < output_shape_4d[1]; ++i) {
for (size_t j = 0; j < output_shape_4d[2]; ++j) {
for (size_t k = 0; k < output_shape_4d[3]; ++k) {
gather_order.push_back(n * xyz[0] + i * xyz[1] + j * xyz[2] + k * xyz[3]);
}
}
}
}
return gather_indices_value;
return gather_order;
}
NodePair SinkForward(NodePtr transpose, std::shared_ptr<Constant> transpose_constant, NodePtr reshape) {
const auto gather_indices_value = CreateGatherIndices(transpose->get_input_shape(0),
reshape->get_output_shape(0),
transpose_constant->get_axis_vector_val());
const int64_t gather_axis_value = GetFirstValuableDimId(reshape->get_output_shape(0));
auto reshape_new = reshape->clone_with_new_inputs({transpose->input_value(0), reshape->input_value(1)});
@ -111,19 +101,10 @@ NodePair SinkForward(NodePtr transpose, std::shared_ptr<Constant> transpose_cons
return std::make_pair(reshape_new, gather);
}
Shape TransposeShape(const Shape& shape, AxisVector transpose_axis) {
Shape transposed(shape.size());
for (size_t i = 0; i < shape.size(); ++i) {
transposed[i] = shape[transpose_axis[i]];
}
return transposed;
}
NodePair SinkBackward(NodePtr transpose, std::shared_ptr<Constant> transpose_constant, NodePtr reshape) {
const int64_t gather_axis_value = GetFirstValuableDimId(reshape->get_input_shape(0));
const auto gather_indices_value =
CreateGatherIndices(TransposeShape(transpose->get_output_shape(0), transpose_constant->get_axis_vector_val()),
reshape->get_input_shape(0),
CreateGatherIndices(transpose->get_input_shape(0),
transpose_constant->get_axis_vector_val());
auto gather_axis = std::make_shared<Constant>(element::i64, Shape{}, gather_axis_value);

View File

@ -9,6 +9,7 @@
#include <ngraph/rt_info.hpp>
#include "ops/gna_convolution.hpp"
#include "ops/gna_max_pool.hpp"
#include "log/debug.hpp"
using namespace ov::opset7;
@ -123,11 +124,26 @@ void RemoveSingleInputNodeFromFunction(std::shared_ptr<ov::Node> node) {
}
ov::Shape SqueezeShape(const ov::Shape& shape) {
ov::Shape squeezed_shape;
std::copy_if(shape.begin(), shape.end(), std::back_inserter(squeezed_shape), [](size_t x) {
return x != 1;
});
return squeezed_shape;
auto comp = [](size_t x) { return x != 1; };
auto start_it = std::find_if(shape.begin(), shape.end(), comp);
auto end_it = std::find_if(shape.rbegin(), shape.rend(), comp).base();
if (start_it == shape.end() || end_it == shape.end() || start_it < end_it) {
return ov::Shape(shape.begin(), shape.end());
}
return ov::Shape(start_it, end_it);
}
ov::Shape TransposeShape(const ov::Shape& shape, std::vector<size_t> order) {
if (shape.size() != order.size()) {
THROW_GNA_EXCEPTION << "Sizes of the shape " << shape.size()
<< " and transpose axis " << order.size() << " are different";
}
ov::Shape transposed(shape.size());
for (size_t i = 0; i < shape.size(); ++i) {
transposed[i] = shape[order[i]];
}
return transposed;
}
} // namespace helper

View File

@ -124,6 +124,14 @@ void RemoveSingleInputNodeFromFunction(std::shared_ptr<ov::Node> node);
*/
ov::Shape SqueezeShape(const ov::Shape& shape);
/**
* @brief Transpose shape
* @param shape the shape to be transposed
* @param order the permutation to apply to the axes of the input shape
* @return transposed shape
*/
ov::Shape TransposeShape(const ov::Shape& shape, std::vector<size_t> order);
} // namespace helper
} // namespace pass
} // namespace intel_gna