resolve review comments
This commit is contained in:
parent
e4207c4d6b
commit
67c1b9daad
@ -30,7 +30,7 @@ TSSliceForward::TSSliceForward() {
|
||||
|
||||
auto& main_node = pattern_to_node.at(main_node_label);
|
||||
auto transpose = std::dynamic_pointer_cast<Transpose>(pattern_to_node.at(transpose_label));
|
||||
if (!transpose) {
|
||||
if (!transpose || main_node->get_input_size() < 5) {
|
||||
return false;
|
||||
}
|
||||
|
||||
@ -40,7 +40,7 @@ TSSliceForward::TSSliceForward() {
|
||||
}
|
||||
|
||||
// remove Transpose on 1st input:
|
||||
auto transpose_parent = main_node->input_value(0).get_node()->input_value(0);
|
||||
auto transpose_parent = transpose->input_value(0);
|
||||
main_node->input(0).replace_source_output(transpose_parent);
|
||||
|
||||
const auto transpose_axis_order = transpose_const->get_axis_vector_val();
|
||||
@ -85,6 +85,10 @@ TSSliceBackward::TSSliceBackward() {
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto main_node = pattern_to_output.at(main_node_label).get_node_shared_ptr();
|
||||
|
||||
if (main_node->get_input_size() < 5) {
|
||||
return false;
|
||||
}
|
||||
|
||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(main_node,
|
||||
transpose_const,
|
||||
/* input_indexes= */ {0})) {
|
||||
|
@ -23,6 +23,14 @@ using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
namespace {
|
||||
|
||||
/**
|
||||
* @brief Checks that Reshape operation is equal to Squeeze:
|
||||
* Only 1 dims are deleted, all other dims must be the same.
|
||||
* Converts these 1 dims to axes format.
|
||||
* @arg reshape Reshape operation.
|
||||
* @arg reshape_to_shape 2nd input to Reshape op as a constant.
|
||||
* @arg result_axes Contains axes which will be squeezed.
|
||||
*/
|
||||
bool shape_to_squeeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
const std::shared_ptr<Constant>& reshape_to_shape,
|
||||
std::vector<size_t>& result_axes) {
|
||||
@ -61,10 +69,22 @@ bool shape_to_squeeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<size_t> squeeze_axes_to_shape(const Output<Node>& input_node, std::vector<size_t> squeeze_axes) {
|
||||
std::vector<size_t> to_shape;
|
||||
/**
|
||||
* @brief Converts squeezed_axes to actual shape (2nd input) for Reshape operation
|
||||
* using the shape of the 1st input to Reshape.
|
||||
* @arg input_node 1st input to Reshape op.
|
||||
* @arg squeeze_axes In case of Reshape op is equal to squeeze, these axes indicate the places where 1 dims have
|
||||
* to be deleted.
|
||||
*/
|
||||
bool squeeze_axes_to_shape(const Output<Node>& input_node, std::vector<size_t> squeeze_axes,
|
||||
std::vector<size_t>& to_shape) {
|
||||
to_shape.clear();
|
||||
std::sort(squeeze_axes.begin(), squeeze_axes.end());
|
||||
const auto& input_shape = input_node.get_shape(); // check is static
|
||||
const auto& input_pshape = input_node.get_partial_shape();
|
||||
if (input_pshape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
const auto& input_shape = input_pshape.get_shape();
|
||||
for (size_t i = 0, j = 0; i < input_shape.size(); ++i) {
|
||||
if (j < squeeze_axes.size() && i == squeeze_axes[j]) {
|
||||
++j;
|
||||
@ -72,7 +92,7 @@ std::vector<size_t> squeeze_axes_to_shape(const Output<Node>& input_node, std::v
|
||||
}
|
||||
to_shape.push_back(input_shape[i]);
|
||||
}
|
||||
return to_shape;
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
@ -133,7 +153,12 @@ TSSqueezeForward::TSSqueezeForward() {
|
||||
transpose_order_values);
|
||||
|
||||
if (as_type_ptr<Reshape>(squeeze)) {
|
||||
new_values = squeeze_axes_to_shape(transpose->input_value(0), new_values);
|
||||
std::vector<size_t> to_shape;
|
||||
auto success = squeeze_axes_to_shape(transpose->input_value(0), new_values, to_shape);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
new_values = to_shape;
|
||||
}
|
||||
|
||||
auto new_const = Constant::create(squeeze_axes->get_element_type(), squeeze_axes->get_shape(), new_values);
|
||||
@ -215,7 +240,12 @@ TSSqueezeBackward::TSSqueezeBackward() {
|
||||
transpose_order_values);
|
||||
auto new_transpose = transpose->clone_with_new_inputs({squeeze->input_value(0), new_transpose_order});
|
||||
if (as_type_ptr<Reshape>(squeeze)) {
|
||||
new_values = squeeze_axes_to_shape(new_transpose->output(0), new_values);
|
||||
std::vector<size_t> to_shape;
|
||||
auto success = squeeze_axes_to_shape(new_transpose->output(0), new_values, to_shape);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
new_values = to_shape;
|
||||
}
|
||||
|
||||
std::shared_ptr<Node> new_squeeze;
|
||||
|
@ -23,6 +23,14 @@ using namespace ov::pass::transpose_sinking::utils;
|
||||
|
||||
namespace {
|
||||
|
||||
/**
|
||||
* @brief Checks that Reshape operation is equal to Unsqueeze:
|
||||
* Only 1 dims are inserted, all other dims must be the same.
|
||||
* Converts these 1 dims to axes format.
|
||||
* @arg reshape Reshape operation.
|
||||
* @arg reshape_to_shape 2nd input to Reshape op as a constant.
|
||||
* @arg result_axes contains axes which will be unsqueezed.
|
||||
*/
|
||||
bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
const std::shared_ptr<Constant>& reshape_to_shape,
|
||||
std::vector<size_t>& result_axes) {
|
||||
@ -60,9 +68,22 @@ bool shape_to_unsqueeze_axes(const std::shared_ptr<Node>& reshape,
|
||||
return true;
|
||||
}
|
||||
|
||||
std::vector<size_t> unsqueeze_axes_to_shape(const Output<Node>& input_node, std::vector<size_t> unsqueeze_axes) {
|
||||
const auto& input_shape = input_node.get_shape(); // check is static
|
||||
std::vector<size_t> to_shape(input_shape.size() + unsqueeze_axes.size());
|
||||
/**
|
||||
* @brief Converts unsqueeze_axes to actual shape (2nd input) for Reshape operation
|
||||
* using the shape of the 1st input to Reshape.
|
||||
* @arg input_node 1st input to Reshape op.
|
||||
* @arg unsqueeze_axes In case of Reshape op is equal to Unsqueeze, these axes indicate the places where 1 dims have
|
||||
* to be inserted.
|
||||
*/
|
||||
bool unsqueeze_axes_to_shape(const Output<Node>& input_node, std::vector<size_t> unsqueeze_axes,
|
||||
std::vector<size_t>& to_shape) {
|
||||
to_shape.clear();
|
||||
const auto& input_pshape = input_node.get_partial_shape();
|
||||
if (input_pshape.is_dynamic()) {
|
||||
return false;
|
||||
}
|
||||
const auto& input_shape = input_pshape.get_shape();
|
||||
to_shape.resize(input_shape.size() + unsqueeze_axes.size());
|
||||
std::sort(unsqueeze_axes.begin(), unsqueeze_axes.end());
|
||||
for (size_t i = 0, j = 0, k = 0; i < to_shape.size(); ++i) {
|
||||
if (j < unsqueeze_axes.size() && i == unsqueeze_axes[j]) {
|
||||
@ -73,7 +94,7 @@ std::vector<size_t> unsqueeze_axes_to_shape(const Output<Node>& input_node, std:
|
||||
k++;
|
||||
}
|
||||
}
|
||||
return to_shape;
|
||||
return true;
|
||||
}
|
||||
} // namespace
|
||||
|
||||
@ -114,7 +135,11 @@ TSUnsqueezeForward::TSUnsqueezeForward() {
|
||||
|
||||
std::shared_ptr<Node> new_unsqueeze;
|
||||
if (as_type_ptr<Reshape>(unsqueeze)) {
|
||||
auto new_values = unsqueeze_axes_to_shape(transpose->input_value(0), non_negative_axes);
|
||||
std::vector<size_t> new_values;
|
||||
auto success = unsqueeze_axes_to_shape(transpose->input_value(0), non_negative_axes, new_values);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
auto new_const = Constant::create(unsqueeze_axes->get_element_type(), {new_values.size()}, new_values);
|
||||
new_unsqueeze = unsqueeze->clone_with_new_inputs({transpose->input_value(0), new_const});
|
||||
} else {
|
||||
@ -194,7 +219,12 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
|
||||
|
||||
auto new_transpose = transpose->clone_with_new_inputs({unsqueeze->input_value(0), new_transpose_order});
|
||||
if (as_type_ptr<Reshape>(unsqueeze)) {
|
||||
new_values = unsqueeze_axes_to_shape(new_transpose->output(0), new_values);
|
||||
std::vector<size_t> to_shape;
|
||||
auto success = unsqueeze_axes_to_shape(new_transpose->output(0), new_values, to_shape);
|
||||
if (!success) {
|
||||
return false;
|
||||
}
|
||||
new_values = to_shape;
|
||||
}
|
||||
auto new_const = Constant::create(unsqueeze_axes->get_element_type(), unsqueeze_axes->get_shape(), new_values);
|
||||
auto new_unsqueeze = unsqueeze->clone_with_new_inputs({new_transpose, new_const});
|
||||
|
Loading…
Reference in New Issue
Block a user