Squeeze shapes when detecting gather

This commit is contained in:
Mikhail Ryzhov 2023-03-21 12:12:18 +01:00
parent 7b1bbf77ee
commit faf2f253c6

View File

@ -47,7 +47,16 @@ size_t GetSliceNum(const Shape& transpose_order) {
return slice_count;
}
std::vector<int64_t> CreateGatherIndices(const Shape& transpose_input_shape,
/**
 * @brief Locates the first dimension of a shape whose extent is not 1.
 *
 * @param shape Shape to scan from its first dimension onwards.
 * @return Index of the first dimension with a value other than 1. Returns 0 when
 *         no such dimension exists (every dimension equals 1, or the shape is
 *         empty), so a result of 0 alone does not tell the caller whether
 *         dimension 0 is non-unit or nothing was found.
 */
inline size_t GetFirstValuableDimId(const ov::Shape& shape) {
    size_t dim_id = 0;
    // Step over the leading dimensions that are equal to 1.
    while (dim_id < shape.size() && shape[dim_id] == 1) {
        ++dim_id;
    }
    // Running off the end means nothing was found; fall back to index 0.
    return dim_id < shape.size() ? dim_id : 0;
}
std::vector<size_t> CreateGatherIndices(const Shape& transpose_input_shape,
const Shape& reshape_output_shape,
const Shape& transpose_order) {
const size_t slice_0_end = FindEndOfSlice(transpose_order, 0);
@ -73,7 +82,7 @@ std::vector<int64_t> CreateGatherIndices(const Shape& transpose_input_shape,
return result *= transpose_input_shape[order_value];
});
std::vector<int64_t> gather_indices_value(reshape_output_shape.back());
std::vector<size_t> gather_indices_value(helper::SqueezeShape(reshape_output_shape).back());
for (size_t i = 0; i < gather_indices_value.size(); ++i) {
gather_indices_value[i] = transpose_part_0 * (i % transpose_part_1) + i / transpose_part_1;
}
@ -85,7 +94,7 @@ NodePair SinkForward(NodePtr transpose, std::shared_ptr<Constant> transpose_cons
const auto gather_indices_value = CreateGatherIndices(transpose->get_input_shape(0),
reshape->get_output_shape(0),
transpose_constant->get_axis_vector_val());
const int64_t gather_axis_value = reshape->get_output_shape(0).size() - 1;
const int64_t gather_axis_value = GetFirstValuableDimId(reshape->get_output_shape(0));
auto reshape_new = reshape->clone_with_new_inputs({transpose->input_value(0), reshape->input_value(1)});
@ -111,8 +120,7 @@ Shape TransposeShape(const Shape& shape, AxisVector transpose_axis) {
}
NodePair SinkBackward(NodePtr transpose, std::shared_ptr<Constant> transpose_constant, NodePtr reshape) {
const int64_t gather_axis_value = reshape->get_input_shape(0).size() - 1;
const int64_t gather_axis_value = GetFirstValuableDimId(reshape->get_input_shape(0));
const auto gather_indices_value =
CreateGatherIndices(TransposeShape(transpose->get_output_shape(0), transpose_constant->get_axis_vector_val()),
reshape->get_input_shape(0),