Merge branch 'gna_layout_debug' into gather_sinking_reshape
This commit is contained in:
commit
894defdcc9
@ -11,6 +11,7 @@
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/utils/gather_sinking_utils.hpp"
|
||||
#include "transformations/utils/transformation_helper.hpp"
|
||||
#include "log/debug.hpp"
|
||||
|
||||
using namespace ov::intel_gna::pass;
|
||||
using namespace ov;
|
||||
@ -23,20 +24,6 @@ namespace {
|
||||
using NodePtr = std::shared_ptr<ov::Node>;
|
||||
using NodePair = std::pair<NodePtr, NodePtr>;
|
||||
|
||||
/*
|
||||
* Finds the end of the slice transpose order. Slice here is the subvector
|
||||
* with increasing values transpose_order[i + 1] == transpose_order[i] + 1
|
||||
*
* @param transpose_order sequence of axis indices that is scanned
* @param start_idx index of the first element of the slice
* @return index of the last element of the slice that starts at start_idx; this is
*         start_idx itself when the following element does not continue the run
*         or when there is no following element
*/
|
||||
size_t FindEndOfSlice(const Shape& transpose_order, size_t start_idx) {
|
||||
// The slice initially consists of the start element only.
size_t slice_end = start_idx;
|
||||
for (size_t i = start_idx + 1; i < transpose_order.size(); ++i) {
|
||||
// Stop at the first value that is not exactly one greater than the value at the current slice end.
if (transpose_order[i] != transpose_order[slice_end] + 1)
|
||||
break;
|
||||
slice_end = i;
|
||||
}
|
||||
return slice_end;
|
||||
}
|
||||
|
||||
size_t GetSliceNum(const Shape& transpose_order) {
|
||||
size_t slice_count = 0;
|
||||
for (size_t i = 0; i < transpose_order.size(); ++i) {
|
||||
@ -56,44 +43,47 @@ inline size_t GetFirstValuableDimId(const ov::Shape& shape) {
|
||||
return 0;
|
||||
}
|
||||
|
||||
// NOTE(review): this hunk (-56,44 +43,47) shows the removed three-argument version and the added
// two-argument version of CreateGatherIndices interleaved without +/- markers.
//
// Two-argument version: for every element of the transposed tensor, visited with the last output
// dimension varying fastest, it stores the flat offset of that element in the input tensor.
// @param input_shape shape of the tensor before the transpose; only 2 to 4 dimensions are accepted
// @param order       axis permutation; presumably has the same number of entries as input_shape - TODO confirm
// @return flat input offsets, one per output element
// Raises through THROW_GNA_EXCEPTION when input_shape has fewer than 2 or more than 4 dimensions.
std::vector<size_t> CreateGatherIndices(const Shape& transpose_input_shape,
|
||||
const Shape& reshape_output_shape,
|
||||
const Shape& transpose_order) {
|
||||
const size_t slice_0_end = FindEndOfSlice(transpose_order, 0);
|
||||
const size_t slice_1_start = slice_0_end + 1;
|
||||
const size_t slice_1_end = FindEndOfSlice(transpose_order, slice_1_start);
|
||||
const size_t slice_2_start = slice_1_end + 1;
|
||||
|
||||
if (slice_0_end >= transpose_input_shape.size() || slice_1_start >= transpose_input_shape.size() ||
|
||||
slice_1_end >= transpose_input_shape.size() || slice_2_start >= transpose_input_shape.size()) {
|
||||
return {};
|
||||
std::vector<size_t> CreateGatherIndices(const Shape& input_shape,
|
||||
const Shape& order) {
|
||||
if (input_shape.size() < 2 || input_shape.size() > 4) {
|
||||
// NOTE(review): the message text below contains a typo ("Usupported").
THROW_GNA_EXCEPTION << "Usupported shape size: " << input_shape.size();
|
||||
}
|
||||
|
||||
const int64_t transpose_part_0 = std::accumulate(transpose_order.begin() + slice_1_start,
|
||||
transpose_order.begin() + slice_1_end + 1,
|
||||
1,
|
||||
[&transpose_input_shape](int64_t result, int64_t order_value) {
|
||||
return result *= transpose_input_shape[order_value];
|
||||
});
|
||||
const int64_t transpose_part_1 = std::accumulate(transpose_order.begin() + slice_2_start,
|
||||
transpose_order.end(),
|
||||
1,
|
||||
[&transpose_input_shape](int64_t result, int64_t order_value) {
|
||||
return result *= transpose_input_shape[order_value];
|
||||
});
|
||||
ov::Shape input_shape_4d = input_shape;
|
||||
ov::Shape order_4d = order;
|
||||
// To simplify the code, the shape is extended to 4D by appending dimensions of size 1; the order is
// extended with the index of each appended axis so that those axes stay in place
|
||||
while (input_shape_4d.size() < 4) {
|
||||
input_shape_4d.push_back(1);
|
||||
order_4d.push_back(order_4d.size());
|
||||
}
|
||||
ov::Shape output_shape_4d = helper::TransposeShape(input_shape_4d, order_4d);
|
||||
|
||||
std::vector<size_t> gather_indices_value(helper::SqueezeShape(reshape_output_shape).back());
|
||||
for (size_t i = 0; i < gather_indices_value.size(); ++i) {
|
||||
gather_indices_value[i] = transpose_part_0 * (i % transpose_part_1) + i / transpose_part_1;
|
||||
// common case when shape is 4d: element strides of the padded input shape, last dimension contiguous
|
||||
std::vector<size_t> xyz_4d = { input_shape_4d[3] * input_shape_4d[2] * input_shape_4d[1],
|
||||
input_shape_4d[3] * input_shape_4d[2],
|
||||
input_shape_4d[3],
|
||||
1 };
|
||||
|
||||
// NOTE(review): `order` is the unpadded argument while xyz_4d always has 4 entries;
// helper::TransposeShape rejects mismatched sizes, so 2D/3D inputs would throw here.
// `order_4d` was presumably intended - TODO confirm.
std::vector<size_t> xyz = helper::TransposeShape(xyz_4d, order);
|
||||
std::vector<size_t> gather_order;
|
||||
|
||||
// Walk the output tensor dimension by dimension (k innermost) and record the input offset of each element.
for (size_t n = 0; n < output_shape_4d[0]; ++n) {
|
||||
for (size_t i = 0; i < output_shape_4d[1]; ++i) {
|
||||
for (size_t j = 0; j < output_shape_4d[2]; ++j) {
|
||||
for (size_t k = 0; k < output_shape_4d[3]; ++k) {
|
||||
gather_order.push_back(n * xyz[0] + i * xyz[1] + j * xyz[2] + k * xyz[3]);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return gather_indices_value;
|
||||
return gather_order;
|
||||
}
|
||||
|
||||
NodePair SinkForward(NodePtr transpose, std::shared_ptr<Constant> transpose_constant, NodePtr reshape) {
|
||||
const auto gather_indices_value = CreateGatherIndices(transpose->get_input_shape(0),
|
||||
reshape->get_output_shape(0),
|
||||
transpose_constant->get_axis_vector_val());
|
||||
|
||||
const int64_t gather_axis_value = GetFirstValuableDimId(reshape->get_output_shape(0));
|
||||
|
||||
auto reshape_new = reshape->clone_with_new_inputs({transpose->input_value(0), reshape->input_value(1)});
|
||||
@ -111,19 +101,10 @@ NodePair SinkForward(NodePtr transpose, std::shared_ptr<Constant> transpose_cons
|
||||
return std::make_pair(reshape_new, gather);
|
||||
}
|
||||
|
||||
// Permutes the dimensions of a shape: element i of the result is shape[transpose_axis[i]].
// @param shape          shape whose dimensions are permuted
// @param transpose_axis for each output position, the index of the source dimension in shape
// @return shape with the same number of dimensions as the input, in permuted order
// NOTE(review): nothing is validated here; transpose_axis is assumed to hold at least shape.size()
// entries, each a valid index into shape - TODO confirm at call sites.
Shape TransposeShape(const Shape& shape, AxisVector transpose_axis) {
|
||||
Shape transposed(shape.size());
|
||||
for (size_t i = 0; i < shape.size(); ++i) {
|
||||
transposed[i] = shape[transpose_axis[i]];
|
||||
}
|
||||
return transposed;
|
||||
}
|
||||
|
||||
NodePair SinkBackward(NodePtr transpose, std::shared_ptr<Constant> transpose_constant, NodePtr reshape) {
|
||||
const int64_t gather_axis_value = GetFirstValuableDimId(reshape->get_input_shape(0));
|
||||
const auto gather_indices_value =
|
||||
CreateGatherIndices(TransposeShape(transpose->get_output_shape(0), transpose_constant->get_axis_vector_val()),
|
||||
reshape->get_input_shape(0),
|
||||
CreateGatherIndices(transpose->get_input_shape(0),
|
||||
transpose_constant->get_axis_vector_val());
|
||||
|
||||
auto gather_axis = std::make_shared<Constant>(element::i64, Shape{}, gather_axis_value);
|
||||
|
@ -9,6 +9,7 @@
|
||||
#include <ngraph/rt_info.hpp>
|
||||
#include "ops/gna_convolution.hpp"
|
||||
#include "ops/gna_max_pool.hpp"
|
||||
#include "log/debug.hpp"
|
||||
|
||||
using namespace ov::opset7;
|
||||
|
||||
@ -123,11 +124,26 @@ void RemoveSingleInputNodeFromFunction(std::shared_ptr<ov::Node> node) {
|
||||
}
|
||||
|
||||
// NOTE(review): this hunk (-123,11 +124,26) shows two versions of the body interleaved without +/-
// markers. The copy_if version drops every dimension equal to 1; the find_if version keeps the
// contiguous range from the first to the last dimension that is not 1, so interior 1s are preserved.
ov::Shape SqueezeShape(const ov::Shape& shape) {
|
||||
ov::Shape squeezed_shape;
|
||||
std::copy_if(shape.begin(), shape.end(), std::back_inserter(squeezed_shape), [](size_t x) {
|
||||
return x != 1;
|
||||
});
|
||||
return squeezed_shape;
|
||||
// Predicate selecting dimensions that are different from 1.
auto comp = [](size_t x) { return x != 1; };
|
||||
|
||||
// First dimension that is not 1 (shape.end() when every dimension is 1).
auto start_it = std::find_if(shape.begin(), shape.end(), comp);
|
||||
// One past the last dimension that is not 1: base() of the reverse iterator found from the back.
auto end_it = std::find_if(shape.rbegin(), shape.rend(), comp).base();
|
||||
// NOTE(review): end_it equals shape.end() whenever the last dimension is not 1, so such shapes are
// returned unchanged even when they start with 1s (e.g. {1, 1, 3, 4}). Only start_it == shape.end()
// (all dimensions are 1) seems to need this fallback - TODO confirm intent.
if (start_it == shape.end() || end_it == shape.end()) {
|
||||
return ov::Shape(shape.begin(), shape.end());
|
||||
}
|
||||
return ov::Shape(start_it, end_it);
|
||||
}
|
||||
|
||||
// Permutes the dimensions of a shape: element i of the result is shape[order[i]].
// @param shape the shape to be transposed
// @param order for each output position, the index of the source dimension in shape
// @return shape with the same number of dimensions as the input, in permuted order
// Raises through THROW_GNA_EXCEPTION when shape and order have different sizes.
// NOTE(review): the values stored in order are not range-checked; an entry >= shape.size() would
// index past the end of shape. order is also taken by value although it is only read.
ov::Shape TransposeShape(const ov::Shape& shape, std::vector<size_t> order) {
|
||||
if (shape.size() != order.size()) {
|
||||
THROW_GNA_EXCEPTION << "Sizes of the shape " << shape.size()
|
||||
<< " and transpose axis " << order.size() << " are different";
|
||||
}
|
||||
ov::Shape transposed(shape.size());
|
||||
for (size_t i = 0; i < shape.size(); ++i) {
|
||||
transposed[i] = shape[order[i]];
|
||||
}
|
||||
return transposed;
|
||||
}
|
||||
|
||||
} // namespace helper
|
||||
|
@ -124,6 +124,14 @@ void RemoveSingleInputNodeFromFunction(std::shared_ptr<ov::Node> node);
|
||||
*/
|
||||
ov::Shape SqueezeShape(const ov::Shape& shape);
|
||||
|
||||
/**
|
||||
* @brief Transpose shape
|
||||
* @param shape the shape to be transposed
|
||||
* @param order the permutation to apply to the axes of the input shape
|
||||
* @return transposed shape
|
||||
*/
|
||||
ov::Shape TransposeShape(const ov::Shape& shape, std::vector<size_t> order);
|
||||
|
||||
} // namespace helper
|
||||
} // namespace pass
|
||||
} // namespace intel_gna
|
||||
|
Loading…
Reference in New Issue
Block a user