Fix TransposeSinking for Gather (#19202)
* Fix TS gather * enable pytest * revert auto replaced comment
This commit is contained in:
parent
13f8ff4a40
commit
8509737d0a
@ -240,12 +240,14 @@ TSGatherBackward::TSGatherBackward() {
|
||||
if (success) {
|
||||
size_t j = 0;
|
||||
for (size_t i = 0; i < shape.size(); ++i) {
|
||||
if (shape[i] != new_shape[j] && shape[i] == 1) {
|
||||
axes_val.push_back(i);
|
||||
continue;
|
||||
} else if (shape[i] != new_shape[j]) {
|
||||
success = false;
|
||||
break;
|
||||
if (j >= new_shape.size() || shape[i] != new_shape[j]) {
|
||||
if (shape[i] == 1) {
|
||||
axes_val.push_back(i);
|
||||
continue;
|
||||
} else {
|
||||
success = false;
|
||||
break;
|
||||
}
|
||||
}
|
||||
j++;
|
||||
}
|
||||
|
@ -125,8 +125,9 @@ INSTANTIATE_TEST_SUITE_P(TSCommonGatherForward_3, TSTestFixture, test_forward_ga
|
||||
|
||||
struct GatherBackwardArguments {
|
||||
OutputVector inputs_to_main;
|
||||
Output<Node> new_Gather_first_input;
|
||||
AxisVector new_transpose_order;
|
||||
Output<Node> ref_Gather_axis_input;
|
||||
AxisVector ref_transpose_order;
|
||||
AxisVector ref_unsqueeze_axes;
|
||||
};
|
||||
|
||||
auto test_backward_gather = [](const GatherBackwardArguments& test_arguments) {
|
||||
@ -147,14 +148,14 @@ auto test_backward_gather = [](const GatherBackwardArguments& test_arguments) {
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] = out_vec[1];
|
||||
new_out_vec[2] = test_arguments.new_Gather_first_input;
|
||||
new_out_vec[2] = test_arguments.ref_Gather_axis_input;
|
||||
return new_out_vec;
|
||||
};
|
||||
auto new_transpose = [&test_arguments](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
OutputVector new_out_vec = out_vec;
|
||||
auto order = make_shared<Constant>(i32,
|
||||
Shape{test_arguments.new_transpose_order.size()},
|
||||
test_arguments.new_transpose_order);
|
||||
Shape{test_arguments.ref_transpose_order.size()},
|
||||
test_arguments.ref_transpose_order);
|
||||
new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
|
||||
return new_out_vec;
|
||||
};
|
||||
@ -197,13 +198,14 @@ auto test_backward_gather_optimization = [](const GatherBackwardArguments& test_
|
||||
OutputVector new_out_vec(out_vec.size());
|
||||
new_out_vec[0] = out_vec[0];
|
||||
new_out_vec[1] = make_shared<Squeeze>(out_vec[1]);
|
||||
new_out_vec[2] = test_arguments.new_Gather_first_input;
|
||||
new_out_vec[2] = test_arguments.ref_Gather_axis_input;
|
||||
return new_out_vec;
|
||||
};
|
||||
|
||||
auto unsqueeze_for = [&](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||
auto axis = constant<int>(i32, {1}, {0});
|
||||
return {make_shared<Unsqueeze>(out_vec[0], axis)};
|
||||
const auto& axes_val = test_arguments.ref_unsqueeze_axes;
|
||||
auto axes = constant<size_t>(i32, {axes_val.size()}, axes_val);
|
||||
return {make_shared<Unsqueeze>(out_vec[0], axes)};
|
||||
};
|
||||
|
||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, update_gather_inputs}, {{0}, {1, 2}}};
|
||||
@ -215,13 +217,29 @@ auto test_backward_gather_optimization = [](const GatherBackwardArguments& test_
|
||||
};
|
||||
|
||||
vector<GatherBackwardArguments> tests_arguments_bw_optimization{
|
||||
{{{parameter(f32, {257, 8}), constant<int>(i32, {1, 2}, {0}), constant<int>(i32, {1}, {0})},
|
||||
constant<int>(i32, {1}, {1}),
|
||||
AxisVector{}}}};
|
||||
{{parameter(f32, {257, 8}), constant<int>(i32, {1, 2}, {0}), constant<int>(i32, {1}, {0})},
|
||||
constant<int>(i32, {1}, {1}),
|
||||
AxisVector{},
|
||||
AxisVector{0}},
|
||||
{{parameter(f32, {4}), constant<int>(i32, {1}, {0}), constant<int>(i32, {1}, {0})},
|
||||
constant<int>(i32, {1}, {0}),
|
||||
AxisVector{},
|
||||
AxisVector{0}},
|
||||
{{parameter(f32, {4}), constant<int>(i32, {1, 1, 1}, {0}), constant<int>(i32, {1}, {0})},
|
||||
constant<int>(i32, {1}, {0}),
|
||||
AxisVector{},
|
||||
AxisVector{0, 1, 2}},
|
||||
};
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackwardOptimization_0,
|
||||
TSTestFixture,
|
||||
test_backward_gather_optimization(tests_arguments_bw_optimization[0]));
|
||||
INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackwardOptimization_1,
|
||||
TSTestFixture,
|
||||
test_backward_gather_optimization(tests_arguments_bw_optimization[1]));
|
||||
INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackwardOptimization_2,
|
||||
TSTestFixture,
|
||||
test_backward_gather_optimization(tests_arguments_bw_optimization[2]));
|
||||
} // namespace gather
|
||||
} // namespace testing
|
||||
} // namespace transpose_sinking
|
||||
|
@ -55,11 +55,7 @@ class TestMaxPoolWithArgmax(CommonTFLayerTest):
|
||||
True, False
|
||||
])
|
||||
@pytest.mark.parametrize("with_second_output", [
|
||||
pytest.param(
|
||||
True,
|
||||
marks=pytest.mark.skip(reason="117415: TransposeSinking crash")
|
||||
),
|
||||
False
|
||||
True, False
|
||||
])
|
||||
@pytest.mark.precommit_tf_fe
|
||||
@pytest.mark.nightly
|
||||
|
Loading…
Reference in New Issue
Block a user