resolve review comments

This commit is contained in:
Ivan 2023-03-23 17:47:50 +04:00
parent e4207c4d6b
commit 67c1b9daad
3 changed files with 78 additions and 14 deletions

View File

@ -30,7 +30,7 @@ TSSliceForward::TSSliceForward() {
auto& main_node = pattern_to_node.at(main_node_label);
auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
if (!transpose) {
if (!transpose || main_node->get_input_size() < 5) {
return false;
}
@ -40,7 +40,7 @@ TSSliceForward::TSSliceForward() {
}
// remove Transpose on 1st input:
auto transpose_parent = main_node->input_value(0).get_node()->input_value(0);
auto transpose_parent = transpose->input_value(0);
main_node->input(0).replace_source_output(transpose_parent);
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
@ -85,6 +85,10 @@ TSSliceBackward::TSSliceBackward() {
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
if (main_node->get_input_size() < 5) {
return false;
}
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
transpose_const,
/* input_indexes= */ {0})) {

View File

@ -23,6 +23,14 @@ using namespace ov::pass::transpose_sinking::utils;
namespace {
/**
* @brief Checks that Reshape operation is equal to Squeeze:
* Only 1 dims are deleted, all other dims must be the same.
* Converts these 1 dims to axes format.
* @arg reshape Reshape operation.
* @arg reshape_to_shape 2nd input to Reshape op as a constant.
* @arg result_axes Contains axes which will be squeezed.
*/
bool shape_to_squeeze_axes(const std::shared_ptr<Node>& reshape,
const std::shared_ptr<Constant>& reshape_to_shape,
std::vector<size_t>& result_axes) {
@ -61,10 +69,22 @@ bool shape_to_squeeze_axes(const std::shared_ptr<Node>& reshape,
return true;
}
std::vector<size_t> squeeze_axes_to_shape(const Output<Node>& input_node, std::vector<size_t> squeeze_axes) {
std::vector<size_t> to_shape;
/**
* @brief Converts squeezed_axes to actual shape (2nd input) for Reshape operation
* using the shape of the 1st input to Reshape.
* @arg input_node 1st input to Reshape op.
* @arg squeeze_axes In case of Reshape op is equal to squeeze, these axes indicate the places where 1 dims have
* to be deleted.
*/
bool squeeze_axes_to_shape(const Output<Node>& input_node, std::vector<size_t> squeeze_axes,
std::vector<size_t>& to_shape) {
to_shape.clear();
std::sort(squeeze_axes.begin(), squeeze_axes.end());
const auto& input_shape = input_node.get_shape(); // check is static
const auto& input_pshape = input_node.get_partial_shape();
if (input_pshape.is_dynamic()) {
return false;
}
const auto& input_shape = input_pshape.get_shape();
for (size_t i = 0, j = 0; i < input_shape.size(); ++i) {
if (j < squeeze_axes.size() && i == squeeze_axes[j]) {
++j;
@ -72,7 +92,7 @@ std::vector<size_t> squeeze_axes_to_shape(const Output<Node>& input_node, std::v
}
to_shape.push_back(input_shape[i]);
}
return to_shape;
return true;
}
} // namespace
@ -133,7 +153,12 @@ TSSqueezeForward::TSSqueezeForward() {
transpose_order_values);
if (as_type_ptr<Reshape>(squeeze)) {
new_values = squeeze_axes_to_shape(transpose->input_value(0), new_values);
std::vector<size_t> to_shape;
auto success = squeeze_axes_to_shape(transpose->input_value(0), new_values, to_shape);
if (!success) {
return false;
}
new_values = to_shape;
}
auto new_const = Constant::create(squeeze_axes->get_element_type(), squeeze_axes->get_shape(), new_values);
@ -215,7 +240,12 @@ TSSqueezeBackward::TSSqueezeBackward() {
transpose_order_values);
auto new_transpose = transpose->clone_with_new_inputs({squeeze->input_value(0), new_transpose_order});
if (as_type_ptr<Reshape>(squeeze)) {
new_values = squeeze_axes_to_shape(new_transpose->output(0), new_values);
std::vector<size_t> to_shape;
auto success = squeeze_axes_to_shape(new_transpose->output(0), new_values, to_shape);
if (!success) {
return false;
}
new_values = to_shape;
}
std::shared_ptr<Node> new_squeeze;

View File

@ -23,6 +23,14 @@ using namespace ov::pass::transpose_sinking::utils;
namespace {
/**
* @brief Checks that Reshape operation is equal to Unsqueeze:
* Only 1 dims are inserted, all other dims must be the same.
* Converts these 1 dims to axes format.
* @arg reshape Reshape operation.
* @arg reshape_to_shape 2nd input to Reshape op as a constant.
* @arg result_axes contains axes which will be unsqueezed.
*/
bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
const std::shared_ptr<Constant>& reshape_to_shape,
std::vector<size_t>& result_axes) {
@ -60,9 +68,22 @@ bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
return true;
}
std::vector<size_t> unsqueeze_axes_to_shape(const Output<Node>& input_node, std::vector<size_t> unsqueeze_axes) {
const auto& input_shape = input_node.get_shape(); // check is static
std::vector<size_t> to_shape(input_shape.size() + unsqueeze_axes.size());
/**
* @brief Converts unsqueeze_axes to actual shape (2nd input) for Reshape operation
* using the shape of the 1st input to Reshape.
* @arg input_node 1st input to Reshape op.
* @arg unsqueeze_axes In case of Reshape op is equal to Unsqueeze, these axes indicate the places where 1 dims have
* to be inserted.
*/
bool unsqueeze_axes_to_shape(const Output<Node>& input_node, std::vector<size_t> unsqueeze_axes,
std::vector<size_t>& to_shape) {
to_shape.clear();
const auto& input_pshape = input_node.get_partial_shape();
if (input_pshape.is_dynamic()) {
return false;
}
const auto& input_shape = input_pshape.get_shape();
to_shape.resize(input_shape.size() + unsqueeze_axes.size());
std::sort(unsqueeze_axes.begin(), unsqueeze_axes.end());
for (size_t i = 0, j = 0, k = 0; i < to_shape.size(); ++i) {
if (j < unsqueeze_axes.size() && i == unsqueeze_axes[j]) {
@ -73,7 +94,7 @@ std::vector<size_t> unsqueeze_axes_to_shape(const Output<Node>& input_node, std:
k++;
}
}
return to_shape;
return true;
}
} // namespace
@ -114,7 +135,11 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
std::shared_ptr<Node> new_unsqueeze;
if (as_type_ptr<Reshape>(unsqueeze)) {
auto new_values = unsqueeze_axes_to_shape(transpose->input_value(0), non_negative_axes);
std::vector<size_t> new_values;
auto success = unsqueeze_axes_to_shape(transpose->input_value(0), non_negative_axes, new_values);
if (!success) {
return false;
}
auto new_const = Constant::create(unsqueeze_axes->get_element_type(), {new_values.size()}, new_values);
new_unsqueeze = unsqueeze->clone_with_new_inputs({transpose->input_value(0), new_const});
} else {
@ -194,7 +219,12 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
auto new_transpose = transpose->clone_with_new_inputs({unsqueeze->input_value(0), new_transpose_order});
if (as_type_ptr<Reshape>(unsqueeze)) {
new_values = unsqueeze_axes_to_shape(new_transpose->output(0), new_values);
std::vector<size_t> to_shape;
auto success = unsqueeze_axes_to_shape(new_transpose->output(0), new_values, to_shape);
if (!success) {
return false;
}
new_values = to_shape;
}
auto new_const = Constant::create(unsqueeze_axes->get_element_type(), unsqueeze_axes->get_shape(), new_values);
auto new_unsqueeze = unsqueeze->clone_with_new_inputs({new_transpose, new_const});