This commit is contained in:
Evgeny Kotov 2023-07-26 06:36:30 +02:00 committed by GitHub
parent 7767af3529
commit 4d3601ac11
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 53 additions and 25 deletions

View File

@ -262,31 +262,42 @@ TSGatherBackward::TSGatherBackward() {
order_val = GetOrderAfterReduction(axes_val, order_val);
}
std::shared_ptr<ov::op::v0::Constant> new_axis;
const auto& indices_rank_val = static_cast<size_t>(main_node->get_input_partial_shape(1).rank().get_length());
std::vector<size_t> new_transpose_order(order_val.size() - indices_rank_val + 1);
for (size_t i = 0, j = 0; i < order_val.size(); ++j) {
if (order_val[i] < axis) {
new_transpose_order[j] = order_val[i];
++i;
} else if (order_val[i] > axis) {
new_transpose_order[j] = order_val[i] - indices_rank_val + 1;
++i;
} else {
// the next `indices_rank_val` values have to be in ascending order
// these values will be replaced with a single axis
new_transpose_order[j] = order_val[i];
size_t prev_idx = i;
for (size_t k = 0; i < order_val.size() && k < indices_rank_val; ++i, ++k) {
if (order_val[i] != order_val[prev_idx]) {
if (success && squeeze) {
main_node->input(1).replace_source_output(squeeze->input_value(0));
std::vector<size_t> new_transpose_order;
if (indices_rank_val > 0) {
new_transpose_order.resize(order_val.size() - indices_rank_val + 1);
for (size_t i = 0, j = 0; i < order_val.size(); ++j) {
if (order_val[i] < axis) {
new_transpose_order[j] = order_val[i];
++i;
} else if (order_val[i] > axis) {
new_transpose_order[j] = order_val[i] - indices_rank_val + 1;
++i;
} else {
// the next `indices_rank_val` values have to be in ascending order
// these values will be replaced with a single axis
new_transpose_order[j] = order_val[i];
size_t prev_idx = i;
for (size_t k = 0; i < order_val.size() && k < indices_rank_val; ++i, ++k) {
if (order_val[i] != order_val[prev_idx]) {
if (success && squeeze) {
main_node->input(1).replace_source_output(squeeze->input_value(0));
}
return false;
}
return false;
prev_idx = i;
}
prev_idx = i;
}
}
} else {
const std::vector<size_t> axes_values = {axis};
new_transpose_order = GetOrderBeforeReduction(axes_values, order_val);
new_axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{1}, axis);
}
RemoveTransposeConsumers(main_node);
if (success) {
auto target_inputs = main_node->get_output_target_inputs(0);
@ -310,7 +321,9 @@ TSGatherBackward::TSGatherBackward() {
/* input_indexes= */ {0})) {
register_new_node(new_node);
}
auto new_axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{1}, reversed_transpose_order[axis]);
if (!new_axis) {
new_axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{1}, reversed_transpose_order[axis]);
}
copy_runtime_info(gather_axis, new_axis);
main_node->input(2).replace_source_output(new_axis);
main_node->validate_and_infer_types();

View File

@ -126,6 +126,7 @@ INSTANTIATE_TEST_SUITE_P(TSCommonGatherForward_3, TSTestFixture, test_forward_ga
struct GatherBackwardArguments {
OutputVector inputs_to_main;
Output<Node> new_Gather_first_input;
AxisVector new_transpose_order;
};
auto test_backward_gather = [](const GatherBackwardArguments& test_arguments) {
@ -149,7 +150,15 @@ auto test_backward_gather = [](const GatherBackwardArguments& test_arguments) {
new_out_vec[2] = test_arguments.new_Gather_first_input;
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {2}}};
auto new_transpose = [&test_arguments](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec = out_vec;
auto order = make_shared<Constant>(i32,
Shape{test_arguments.new_transpose_order.size()},
test_arguments.new_transpose_order);
new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{new_transpose, new_constant}, {{0}, {2}}};
test_case.model_ref.main_op = {CREATE_GATHER_FACTORY(Gather)};
test_case.model_ref.model_template = create_model;
@ -157,10 +166,15 @@ auto test_backward_gather = [](const GatherBackwardArguments& test_arguments) {
};
vector<GatherBackwardArguments> tests_arguments_bw{
{{{parameter(f32, {3, 4, 5, 6}), constant<int>(i32, {2}, {0, 2}), constant<int>(i32, {1}, {2})}},
constant<int>(i32, {1}, {1})}};
{{parameter(f32, {3, 4, 5, 6}), constant<int>(i32, {2}, {0, 2}), constant<int>(i32, {1}, {2})},
constant<int>(i32, {1}, {1}),
AxisVector{3, 2, 1, 0}},
{{parameter(f32, {1, 2, 16, 3, 64}), constant<int>(i32, {}, {0}), constant<int>(i32, {1}, {3})},
constant<int>(i32, {1}, {3}),
AxisVector{4, 2, 1, 3, 0}}};
INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackward_0, TSTestFixture, test_backward_gather(tests_arguments_bw[0]));
INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackward_1, TSTestFixture, test_backward_gather(tests_arguments_bw[1]));
// In some cases shape of 2nd input to Gather op (indices) has `1` dims which can
// prevent TransposeSinking in backward direction.
@ -201,8 +215,9 @@ auto test_backward_gather_optimization = [](const GatherBackwardArguments& test_
};
vector<GatherBackwardArguments> tests_arguments_bw_optimization{
{{{parameter(f32, {257, 8}), constant<int>(i32, {1, 2}, {0}), constant<int>(i32, {1}, {0})}},
constant<int>(i32, {1}, {1})}};
{{{parameter(f32, {257, 8}), constant<int>(i32, {1, 2}, {0}), constant<int>(i32, {1}, {0})},
constant<int>(i32, {1}, {1}),
AxisVector{}}}};
INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackwardOptimization_0,
TSTestFixture,