fix (#18762)
This commit is contained in:
parent
7767af3529
commit
4d3601ac11
@ -262,31 +262,42 @@ TSGatherBackward::TSGatherBackward() {
|
||||
order_val = GetOrderAfterReduction(axes_val, order_val);
|
||||
}
|
||||
|
||||
std::shared_ptr<ov::op::v0::Constant> new_axis;
|
||||
const auto& indices_rank_val = static_cast<size_t>(main_node->get_input_partial_shape(1).rank().get_length());
|
||||
std::vector<size_t> new_transpose_order(order_val.size() - indices_rank_val + 1);
|
||||
for (size_t i = 0, j = 0; i < order_val.size(); ++j) {
|
||||
if (order_val[i] < axis) {
|
||||
new_transpose_order[j] = order_val[i];
|
||||
++i;
|
||||
} else if (order_val[i] > axis) {
|
||||
new_transpose_order[j] = order_val[i] - indices_rank_val + 1;
|
||||
++i;
|
||||
} else {
|
||||
// the next `indices_rank_val` values have to be in ascending order
|
||||
// these values will be replaced with a single axis
|
||||
new_transpose_order[j] = order_val[i];
|
||||
size_t prev_idx = i;
|
||||
for (size_t k = 0; i < order_val.size() && k < indices_rank_val; ++i, ++k) {
|
||||
if (order_val[i] != order_val[prev_idx]) {
|
||||
if (success && squeeze) {
|
||||
main_node->input(1).replace_source_output(squeeze->input_value(0));
|
||||
|
||||
std::vector<size_t> new_transpose_order;
|
||||
if (indices_rank_val > 0) {
|
||||
new_transpose_order.resize(order_val.size() - indices_rank_val + 1);
|
||||
|
||||
for (size_t i = 0, j = 0; i < order_val.size(); ++j) {
|
||||
if (order_val[i] < axis) {
|
||||
new_transpose_order[j] = order_val[i];
|
||||
++i;
|
||||
} else if (order_val[i] > axis) {
|
||||
new_transpose_order[j] = order_val[i] - indices_rank_val + 1;
|
||||
++i;
|
||||
} else {
|
||||
// the next `indices_rank_val` values have to be in ascending order
|
||||
// these values will be replaced with a single axis
|
||||
new_transpose_order[j] = order_val[i];
|
||||
size_t prev_idx = i;
|
||||
for (size_t k = 0; i < order_val.size() && k < indices_rank_val; ++i, ++k) {
|
||||
if (order_val[i] != order_val[prev_idx]) {
|
||||
if (success && squeeze) {
|
||||
main_node->input(1).replace_source_output(squeeze->input_value(0));
|
||||
}
|
||||
return false;
|
||||
}
|
||||
return false;
|
||||
prev_idx = i;
|
||||
}
|
||||
prev_idx = i;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
const std::vector<size_t> axes_values = {axis};
|
||||
new_transpose_order = GetOrderBeforeReduction(axes_values, order_val);
|
||||
new_axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{1}, axis);
|
||||
}
|
||||
|
||||
RemoveTransposeConsumers(main_node);
|
||||
if (success) {
|
||||
auto target_inputs = main_node->get_output_target_inputs(0);
|
||||
@ -310,7 +321,9 @@ TSGatherBackward::TSGatherBackward() {
|
||||
/* input_indexes= */ {0})) {
|
||||
register_new_node(new_node);
|
||||
}
|
||||
auto new_axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{1}, reversed_transpose_order[axis]);
|
||||
if (!new_axis) {
|
||||
new_axis = std::make_shared<ov::op::v0::Constant>(element::i32, Shape{1}, reversed_transpose_order[axis]);
|
||||
}
|
||||
copy_runtime_info(gather_axis, new_axis);
|
||||
main_node->input(2).replace_source_output(new_axis);
|
||||
main_node->validate_and_infer_types();
|
||||
|
@ -126,6 +126,7 @@ INSTANTIATE_TEST_SUITE_P(TSCommonGatherForward_3, TSTestFixture, test_forward_ga
|
||||
struct GatherBackwardArguments {
|
||||
OutputVector inputs_to_main;
|
||||
Output<Node> new_Gather_first_input;
|
||||
AxisVector new_transpose_order;
|
||||
};
|
||||
|
||||
auto test_backward_gather = [](const GatherBackwardArguments& test_arguments) {
|
||||
@ -149,7 +150,15 @@ auto test_backward_gather = [](const GatherBackwardArguments& test_arguments) {
|
||||
new_out_vec[2] = test_arguments.new_Gather_first_input;
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, new_constant}, {{0}, {2}}};
|
||||
auto new_transpose = [&test_arguments](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec = out_vec;
|
||||
auto order = make_shared<Constant>(i32,
|
||||
Shape{test_arguments.new_transpose_order.size()},
|
||||
test_arguments.new_transpose_order);
|
||||
new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
|
||||
return new_out_vec;
|
||||
};
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{new_transpose, new_constant}, {{0}, {2}}};
|
||||
test_case.model_ref.main_op = {CREATE_GATHER_FACTORY(Gather)};
|
||||
test_case.model_ref.model_template = create_model;
|
||||
|
||||
@ -157,10 +166,15 @@ auto test_backward_gather = [](const GatherBackwardArguments& test_arguments) {
|
||||
};
|
||||
|
||||
vector<GatherBackwardArguments> tests_arguments_bw{
|
||||
{{{parameter(f32, {3, 4, 5, 6}), constant<int>(i32, {2}, {0, 2}), constant<int>(i32, {1}, {2})}},
|
||||
constant<int>(i32, {1}, {1})}};
|
||||
{{parameter(f32, {3, 4, 5, 6}), constant<int>(i32, {2}, {0, 2}), constant<int>(i32, {1}, {2})},
|
||||
constant<int>(i32, {1}, {1}),
|
||||
AxisVector{3, 2, 1, 0}},
|
||||
{{parameter(f32, {1, 2, 16, 3, 64}), constant<int>(i32, {}, {0}), constant<int>(i32, {1}, {3})},
|
||||
constant<int>(i32, {1}, {3}),
|
||||
AxisVector{4, 2, 1, 3, 0}}};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackward_0, TSTestFixture, test_backward_gather(tests_arguments_bw[0]));
|
||||
INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackward_1, TSTestFixture, test_backward_gather(tests_arguments_bw[1]));
|
||||
|
||||
// In some cases shape of 2nd input to Gather op (indices) has `1` dims which can
|
||||
// prevent TransposeSinking in backward direction.
|
||||
@ -201,8 +215,9 @@ auto test_backward_gather_optimization = [](const GatherBackwardArguments& test_
|
||||
};
|
||||
|
||||
vector<GatherBackwardArguments> tests_arguments_bw_optimization{
|
||||
{{{parameter(f32, {257, 8}), constant<int>(i32, {1, 2}, {0}), constant<int>(i32, {1}, {0})}},
|
||||
constant<int>(i32, {1}, {1})}};
|
||||
{{{parameter(f32, {257, 8}), constant<int>(i32, {1, 2}, {0}), constant<int>(i32, {1}, {0})},
|
||||
constant<int>(i32, {1}, {1}),
|
||||
AxisVector{}}}};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackwardOptimization_0,
|
||||
TSTestFixture,
|
||||
|
Loading…
Reference in New Issue
Block a user