Squeeze shapes when detect gather
This commit is contained in:
parent
7b1bbf77ee
commit
faf2f253c6
@ -47,7 +47,16 @@ size_t GetSliceNum(const Shape& transpose_order) {
|
||||
return slice_count;
|
||||
}
|
||||
|
||||
std::vector<int64_t> CreateGatherIndices(const Shape& transpose_input_shape,
|
||||
inline size_t GetFirstValuableDimId(const ov::Shape& shape) {
|
||||
for (size_t i = 0; i < shape.size(); ++i) {
|
||||
if (shape[i] != 1) {
|
||||
return i;
|
||||
}
|
||||
}
|
||||
return 0;
|
||||
}
|
||||
|
||||
std::vector<size_t> CreateGatherIndices(const Shape& transpose_input_shape,
|
||||
const Shape& reshape_output_shape,
|
||||
const Shape& transpose_order) {
|
||||
const size_t slice_0_end = FindEndOfSlice(transpose_order, 0);
|
||||
@ -73,7 +82,7 @@ std::vector<int64_t> CreateGatherIndices(const Shape& transpose_input_shape,
|
||||
return result *= transpose_input_shape[order_value];
|
||||
});
|
||||
|
||||
std::vector<int64_t> gather_indices_value(reshape_output_shape.back());
|
||||
std::vector<size_t> gather_indices_value(helper::SqueezeShape(reshape_output_shape).back());
|
||||
for (size_t i = 0; i < gather_indices_value.size(); ++i) {
|
||||
gather_indices_value[i] = transpose_part_0 * (i % transpose_part_1) + i / transpose_part_1;
|
||||
}
|
||||
@ -85,7 +94,7 @@ NodePair SinkForward(NodePtr transpose, std::shared_ptr<Constant> transpose_cons
|
||||
const auto gather_indices_value = CreateGatherIndices(transpose->get_input_shape(0),
|
||||
reshape->get_output_shape(0),
|
||||
transpose_constant->get_axis_vector_val());
|
||||
const int64_t gather_axis_value = reshape->get_output_shape(0).size() - 1;
|
||||
const int64_t gather_axis_value = GetFirstValuableDimId(reshape->get_output_shape(0));
|
||||
|
||||
auto reshape_new = reshape->clone_with_new_inputs({transpose->input_value(0), reshape->input_value(1)});
|
||||
|
||||
@ -111,8 +120,7 @@ Shape TransposeShape(const Shape& shape, AxisVector transpose_axis) {
|
||||
}
|
||||
|
||||
NodePair SinkBackward(NodePtr transpose, std::shared_ptr<Constant> transpose_constant, NodePtr reshape) {
|
||||
const int64_t gather_axis_value = reshape->get_input_shape(0).size() - 1;
|
||||
|
||||
const int64_t gather_axis_value = GetFirstValuableDimId(reshape->get_input_shape(0));
|
||||
const auto gather_indices_value =
|
||||
CreateGatherIndices(TransposeShape(transpose->get_output_shape(0), transpose_constant->get_axis_vector_val()),
|
||||
reshape->get_input_shape(0),
|
||||
|
Loading…
Reference in New Issue
Block a user