Fix TransposeSinking for Gather (#19202)
* Fix TS gather * enable pytest * revert auto replaced comment
This commit is contained in:
parent
13f8ff4a40
commit
8509737d0a
@ -240,12 +240,14 @@ TSGatherBackward::TSGatherBackward() {
|
|||||||
if (success) {
|
if (success) {
|
||||||
size_t j = 0;
|
size_t j = 0;
|
||||||
for (size_t i = 0; i < shape.size(); ++i) {
|
for (size_t i = 0; i < shape.size(); ++i) {
|
||||||
if (shape[i] != new_shape[j] && shape[i] == 1) {
|
if (j >= new_shape.size() || shape[i] != new_shape[j]) {
|
||||||
axes_val.push_back(i);
|
if (shape[i] == 1) {
|
||||||
continue;
|
axes_val.push_back(i);
|
||||||
} else if (shape[i] != new_shape[j]) {
|
continue;
|
||||||
success = false;
|
} else {
|
||||||
break;
|
success = false;
|
||||||
|
break;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
j++;
|
j++;
|
||||||
}
|
}
|
||||||
|
@ -125,8 +125,9 @@ INSTANTIATE_TEST_SUITE_P(TSCommonGatherForward_3, TSTestFixture, test_forward_ga
|
|||||||
|
|
||||||
struct GatherBackwardArguments {
|
struct GatherBackwardArguments {
|
||||||
OutputVector inputs_to_main;
|
OutputVector inputs_to_main;
|
||||||
Output<Node> new_Gather_first_input;
|
Output<Node> ref_Gather_axis_input;
|
||||||
AxisVector new_transpose_order;
|
AxisVector ref_transpose_order;
|
||||||
|
AxisVector ref_unsqueeze_axes;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto test_backward_gather = [](const GatherBackwardArguments& test_arguments) {
|
auto test_backward_gather = [](const GatherBackwardArguments& test_arguments) {
|
||||||
@ -147,14 +148,14 @@ auto test_backward_gather = [](const GatherBackwardArguments& test_arguments) {
|
|||||||
OutputVector new_out_vec(out_vec.size());
|
OutputVector new_out_vec(out_vec.size());
|
||||||
new_out_vec[0] = out_vec[0];
|
new_out_vec[0] = out_vec[0];
|
||||||
new_out_vec[1] = out_vec[1];
|
new_out_vec[1] = out_vec[1];
|
||||||
new_out_vec[2] = test_arguments.new_Gather_first_input;
|
new_out_vec[2] = test_arguments.ref_Gather_axis_input;
|
||||||
return new_out_vec;
|
return new_out_vec;
|
||||||
};
|
};
|
||||||
auto new_transpose = [&test_arguments](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
auto new_transpose = [&test_arguments](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||||
OutputVector new_out_vec = out_vec;
|
OutputVector new_out_vec = out_vec;
|
||||||
auto order = make_shared<Constant>(i32,
|
auto order = make_shared<Constant>(i32,
|
||||||
Shape{test_arguments.new_transpose_order.size()},
|
Shape{test_arguments.ref_transpose_order.size()},
|
||||||
test_arguments.new_transpose_order);
|
test_arguments.ref_transpose_order);
|
||||||
new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
|
new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
|
||||||
return new_out_vec;
|
return new_out_vec;
|
||||||
};
|
};
|
||||||
@ -197,13 +198,14 @@ auto test_backward_gather_optimization = [](const GatherBackwardArguments& test_
|
|||||||
OutputVector new_out_vec(out_vec.size());
|
OutputVector new_out_vec(out_vec.size());
|
||||||
new_out_vec[0] = out_vec[0];
|
new_out_vec[0] = out_vec[0];
|
||||||
new_out_vec[1] = make_shared<Squeeze>(out_vec[1]);
|
new_out_vec[1] = make_shared<Squeeze>(out_vec[1]);
|
||||||
new_out_vec[2] = test_arguments.new_Gather_first_input;
|
new_out_vec[2] = test_arguments.ref_Gather_axis_input;
|
||||||
return new_out_vec;
|
return new_out_vec;
|
||||||
};
|
};
|
||||||
|
|
||||||
auto unsqueeze_for = [&](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
auto unsqueeze_for = [&](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
|
||||||
auto axis = constant<int>(i32, {1}, {0});
|
const auto& axes_val = test_arguments.ref_unsqueeze_axes;
|
||||||
return {make_shared<Unsqueeze>(out_vec[0], axis)};
|
auto axes = constant<size_t>(i32, {axes_val.size()}, axes_val);
|
||||||
|
return {make_shared<Unsqueeze>(out_vec[0], axes)};
|
||||||
};
|
};
|
||||||
|
|
||||||
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, update_gather_inputs}, {{0}, {1, 2}}};
|
test_case.model_ref.preprocess_inputs_to_main = {{set_transpose_for, update_gather_inputs}, {{0}, {1, 2}}};
|
||||||
@ -215,13 +217,29 @@ auto test_backward_gather_optimization = [](const GatherBackwardArguments& test_
|
|||||||
};
|
};
|
||||||
|
|
||||||
vector<GatherBackwardArguments> tests_arguments_bw_optimization{
|
vector<GatherBackwardArguments> tests_arguments_bw_optimization{
|
||||||
{{{parameter(f32, {257, 8}), constant<int>(i32, {1, 2}, {0}), constant<int>(i32, {1}, {0})},
|
{{parameter(f32, {257, 8}), constant<int>(i32, {1, 2}, {0}), constant<int>(i32, {1}, {0})},
|
||||||
constant<int>(i32, {1}, {1}),
|
constant<int>(i32, {1}, {1}),
|
||||||
AxisVector{}}}};
|
AxisVector{},
|
||||||
|
AxisVector{0}},
|
||||||
|
{{parameter(f32, {4}), constant<int>(i32, {1}, {0}), constant<int>(i32, {1}, {0})},
|
||||||
|
constant<int>(i32, {1}, {0}),
|
||||||
|
AxisVector{},
|
||||||
|
AxisVector{0}},
|
||||||
|
{{parameter(f32, {4}), constant<int>(i32, {1, 1, 1}, {0}), constant<int>(i32, {1}, {0})},
|
||||||
|
constant<int>(i32, {1}, {0}),
|
||||||
|
AxisVector{},
|
||||||
|
AxisVector{0, 1, 2}},
|
||||||
|
};
|
||||||
|
|
||||||
INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackwardOptimization_0,
|
INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackwardOptimization_0,
|
||||||
TSTestFixture,
|
TSTestFixture,
|
||||||
test_backward_gather_optimization(tests_arguments_bw_optimization[0]));
|
test_backward_gather_optimization(tests_arguments_bw_optimization[0]));
|
||||||
|
INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackwardOptimization_1,
|
||||||
|
TSTestFixture,
|
||||||
|
test_backward_gather_optimization(tests_arguments_bw_optimization[1]));
|
||||||
|
INSTANTIATE_TEST_SUITE_P(TSCommonGatherBackwardOptimization_2,
|
||||||
|
TSTestFixture,
|
||||||
|
test_backward_gather_optimization(tests_arguments_bw_optimization[2]));
|
||||||
} // namespace gather
|
} // namespace gather
|
||||||
} // namespace testing
|
} // namespace testing
|
||||||
} // namespace transpose_sinking
|
} // namespace transpose_sinking
|
||||||
|
@ -55,11 +55,7 @@ class TestMaxPoolWithArgmax(CommonTFLayerTest):
|
|||||||
True, False
|
True, False
|
||||||
])
|
])
|
||||||
@pytest.mark.parametrize("with_second_output", [
|
@pytest.mark.parametrize("with_second_output", [
|
||||||
pytest.param(
|
True, False
|
||||||
True,
|
|
||||||
marks=pytest.mark.skip(reason="117415: TransposeSinking crash")
|
|
||||||
),
|
|
||||||
False
|
|
||||||
])
|
])
|
||||||
@pytest.mark.precommit_tf_fe
|
@pytest.mark.precommit_tf_fe
|
||||||
@pytest.mark.nightly
|
@pytest.mark.nightly
|
||||||
|
Loading…
Reference in New Issue
Block a user