Fix TransposeSinking for Gather (#19202)

* Fix TS gather

* enable pytest

* revert auto replaced comment
This commit is contained in:
Ivan Tikhonov 2023-08-15 20:23:01 +04:00 committed by GitHub
parent 13f8ff4a40
commit 8509737d0a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 38 additions and 22 deletions

View File

@ -240,12 +240,14 @@ TSGatherBackward::TSGatherBackward() {
if (success) { if (success) {
size_t j = 0; size_t j = 0;
for (size_t i = 0; i < shape.size(); ++i) { for (size_t i = 0; i < shape.size(); ++i) {
if (shape[i] != new_shape[j] && shape[i] == 1) { if (j >= new_shape.size() || shape[i] != new_shape[j]) {
axes_val.push_back(i); if (shape[i] == 1) {
continue; axes_val.push_back(i);
} else if (shape[i] != new_shape[j]) { continue;
success = false; } else {
break; success = false;
break;
}
} }
j++; j++;
} }

View File

@ -125,8 +125,9 @@ INSTANTIATE_TEST_SUITE_P(TSCommonGatherForward_3, TSTestFixture, test_forward_ga
struct GatherBackwardArguments { struct GatherBackwardArguments {
OutputVector inputs_to_main; OutputVector inputs_to_main;
Output<Node> new_Gather_first_input; Output<Node> ref_Gather_axis_input;
AxisVector new_transpose_order; AxisVector ref_transpose_order;
AxisVector ref_unsqueeze_axes;
}; };
auto test_backward_gather = [](const GatherBackwardArguments& test_arguments) { auto test_backward_gather = [](const GatherBackwardArguments& test_arguments) {
@ -147,14 +148,14 @@ auto test_backward_gather = [](const GatherBackwardArguments& test_arguments) {
OutputVector new_out_vec(out_vec.size()); OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0]; new_out_vec[0] = out_vec[0];
new_out_vec[1] = out_vec[1]; new_out_vec[1] = out_vec[1];
new_out_vec[2] = test_arguments.new_Gather_first_input; new_out_vec[2] = test_arguments.ref_Gather_axis_input;
return new_out_vec; return new_out_vec;
}; };
auto new_transpose = [&test_arguments](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector { auto new_transpose = [&test_arguments](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec = out_vec; OutputVector new_out_vec = out_vec;
auto order = make_shared<Constant>(i32, auto order = make_shared<Constant>(i32,
Shape{test_arguments.new_transpose_order.size()}, Shape{test_arguments.ref_transpose_order.size()},
test_arguments.new_transpose_order); test_arguments.ref_transpose_order);
new_out_vec[0] = make_shared<Transpose>(out_vec[0], order); new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
return new_out_vec; return new_out_vec;
}; };
@ -197,13 +198,14 @@ auto test_backward_gather_optimization = [](const GatherBackwardArguments& test_
OutputVector new_out_vec(out_vec.size()); OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0]; new_out_vec[0] = out_vec[0];
new_out_vec[1] = make_shared<Squeeze>(out_vec[1]); new_out_vec[1] = make_shared<Squeeze>(out_vec[1]);
new_out_vec[2] = test_arguments.new_Gather_first_input; new_out_vec[2] = test_arguments.ref_Gather_axis_input;
return new_out_vec; return new_out_vec;
}; };
auto unsqueeze_for = [&](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector { auto unsqueeze_for = [&](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
auto axis = constant<int>(i32, {1}, {0}); const auto& axes_val = test_arguments.ref_unsqueeze_axes;
return {make_shared<Unsqueeze>(out_vec[0], axis)}; auto axes = constant<size_t>(i32, {axes_val.size()}, axes_val);
return {make_shared<Unsqueeze>(out_vec[0], axes)};
}; };
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, update_gather_inputs}, {{0}, {1, 2}}}; test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, update_gather_inputs}, {{0}, {1, 2}}};
@ -215,13 +217,29 @@ auto test_backward_gather_optimization = [](const GatherBackwardArguments& test_
}; };
vector<GatherBackwardArguments> tests_arguments_bw_optimization{ vector<GatherBackwardArguments> tests_arguments_bw_optimization{
{{{parameter(f32, {257, 8}), constant<int>(i32, {1, 2}, {0}), constant<int>(i32, {1}, {0})}, {{parameter(f32, {257, 8}), constant<int>(i32, {1, 2}, {0}), constant<int>(i32, {1}, {0})},
constant<int>(i32, {1}, {1}), constant<int>(i32, {1}, {1}),
AxisVector{}}}}; AxisVector{},
AxisVector{0}},
{{parameter(f32, {4}), constant<int>(i32, {1}, {0}), constant<int>(i32, {1}, {0})},
constant<int>(i32, {1}, {0}),
AxisVector{},
AxisVector{0}},
{{parameter(f32, {4}), constant<int>(i32, {1, 1, 1}, {0}), constant<int>(i32, {1}, {0})},
constant<int>(i32, {1}, {0}),
AxisVector{},
AxisVector{0, 1, 2}},
};
INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackwardOptimization_0, INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackwardOptimization_0,
TSTestFixture, TSTestFixture,
test_backward_gather_optimization(tests_arguments_bw_optimization[0])); test_backward_gather_optimization(tests_arguments_bw_optimization[0]));
INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackwardOptimization_1,
TSTestFixture,
test_backward_gather_optimization(tests_arguments_bw_optimization[1]));
INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackwardOptimization_2,
TSTestFixture,
test_backward_gather_optimization(tests_arguments_bw_optimization[2]));
} // namespace gather } // namespace gather
} // namespace testing } // namespace testing
} // namespace transpose_sinking } // namespace transpose_sinking

View File

@ -55,11 +55,7 @@ class TestMaxPoolWithArgmax(CommonTFLayerTest):
True, False True, False
]) ])
@pytest.mark.parametrize("with_second_output", [ @pytest.mark.parametrize("with_second_output", [
pytest.param( True, False
True,
marks=pytest.mark.skip(reason="117415: TransposeSinking crash")
),
False
]) ])
@pytest.mark.precommit_tf_fe @pytest.mark.precommit_tf_fe
@pytest.mark.nightly @pytest.mark.nightly