fix unit tests, revert changes for TSSlice transformation

This commit is contained in:
Ivan 2023-03-21 20:21:42 +04:00
parent 0f17c5f714
commit e4207c4d6b
2 changed files with 29 additions and 5 deletions

View File

@ -46,8 +46,11 @@ TSSliceForward::TSSliceForward() {
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
main_node->input(4).replace_source_output(
ChangeValuesOrder(main_node->input_value(4), transpose_axis_order, axis));
auto data = std::make_shared<Constant>(element::i32, Shape{transpose_axis_order.size()}, transpose_axis_order);
const auto& indices = main_node->input_value(4);
auto new_axis = std::make_shared<Gather>(data, indices, axis);
main_node->input(4).replace_source_output(new_axis);
main_node->validate_and_infer_types();
TransposeInputsInfo transpose_input_info = {transpose, transpose_const, 0};
@ -87,14 +90,19 @@ TSSliceBackward::TSSliceBackward() {
/* input_indexes= */ {0})) {
register_new_node(new_node);
}
// remove output transposes
RemoveSingleOutputConsumers(main_node);
SwapNames(main_node, transpose);
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
const auto reversed_transpose_order = ReverseTransposeOrder(transpose_axis_order);
auto axis = std::make_shared<Constant>(element::i32, Shape{}, std::vector<int32_t>{0});
main_node->input(4).replace_source_output(
ChangeValuesOrder(main_node->input_value(4), reversed_transpose_order, axis));
auto data =
std::make_shared<Constant>(element::i32, Shape{reversed_transpose_order.size()}, reversed_transpose_order);
const auto& indices = main_node->input_value(4);
auto new_axis = std::make_shared<Gather>(data, indices, axis);
main_node->input(4).replace_source_output(new_axis);
main_node->validate_and_infer_types();
return true;
};

View File

@ -862,6 +862,14 @@ auto test_forward_reshape_unsqueeze = []() {
new_out_vec[0] = make_shared<Transpose>(out_vec[0], order);
return new_out_vec;
};
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] =
make_shared<Constant>(out_vec[1].get_element_type(), out_vec[1].get_shape(), std::vector<int64_t>{6, 1, 5, 1, 4});
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{new_constant}, {{1}}};
test_case.model_ref.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
test_case.model_ref.preprocess_outputs_of_main = {{new_transpose}, {{0}}};
test_case.model_ref.model_template = create_model;
@ -1283,7 +1291,15 @@ auto test_backward_reshape_squeeze = []() {
new_out_vec[1] = out_vec[1];
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{new_transpose}, {{0}}};
auto new_constant = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
OutputVector new_out_vec(out_vec.size());
new_out_vec[0] = out_vec[0];
new_out_vec[1] = make_shared<Constant>(out_vec[1].get_element_type(),
out_vec[1].get_shape(),
std::vector<int64_t>{6, 5, 4});
return new_out_vec;
};
test_case.model_ref.preprocess_inputs_to_main = {{new_transpose, new_constant}, {{0}, {1}}};
test_case.model_ref.main_op = {CREATE_RESHAPE_FACTORY(Reshape)};
test_case.model_ref.model_template = create_model;