[GNA] Added Reshape support to TS patterns (#18391)

* Added Reshape support to TS patterns

* fixed unit test

* replaced duplicated code
This commit is contained in:
Mikhail Ryzhov
2023-07-20 12:08:02 +02:00
committed by GitHub
parent 9254c74362
commit 3b8f58f1af
8 changed files with 69 additions and 24 deletions

View File

@@ -81,6 +81,13 @@ inline bool get_constant_value(const std::shared_ptr<ngraph::opset8::Constant>&
return true;
}
/**
 * @brief Compares two shapes for equality
 * @param shape_1 first shape to compare
 * @param shape_2 second shape to compare
 * @return true when both shapes have the same number of dimensions and every
 *         dimension matches position by position, false otherwise
 */
inline bool are_shapes_equal(const ov::Shape& shape_1, const ov::Shape& shape_2) {
    // The 3-iterator std::equal reads shape_1.size() elements from shape_2,
    // so the lengths have to be checked before the element-wise comparison.
    if (shape_1.size() != shape_2.size()) {
        return false;
    }
    return std::equal(shape_1.begin(), shape_1.end(), shape_2.begin());
}
inline bool is_aligned_split(const std::shared_ptr<ngraph::Node> input_op, size_t input_op_out_index) {
size_t offset = 0;
@@ -609,8 +616,7 @@ inline bool is_reshape_unsqueeze(const ov::Output<ov::Node>& output) {
auto reshape = output.get_node_shared_ptr();
const ov::Shape input_shape = trim_shape(reshape->get_input_shape(0));
const ov::Shape output_shape = trim_shape(reshape->get_output_shape(0));
return (input_shape.size() == output_shape.size()) &&
std::equal(input_shape.begin(), input_shape.end(), output_shape.begin());
return are_shapes_equal(input_shape, output_shape);
}
/**

View File

@@ -139,12 +139,12 @@ void TransformationsPipeline::apply(const std::shared_ptr<ov::Model>& model,
manager.register_pass<ov::intel_gna::pass::ReplaceGnaNHWCLayers>();
manager.register_pass<ov::intel_gna::pass::InsertConvolutionTransposeHW>();
manager.register_pass<ov::pass::TransposeSinkingGeneral>();
manager.register_pass<ov::intel_gna::pass::TSConcatForward>();
manager.register_pass<ov::intel_gna::pass::TSSplitBackward>();
manager.register_pass<ov::intel_gna::pass::GatherSinkingGeneral>();
manager.register_pass<ov::pass::ReshapeSequenceFusion>();
manager.register_pass<ov::pass::TransposeToReshape>();
manager.register_pass<ov::intel_gna::pass::GnaConvolutionFusion>();
manager.register_pass<ov::intel_gna::pass::TSConcatForward>();
manager.register_pass<ov::intel_gna::pass::TSSplitBackward>();
manager.register_pass<ov::pass::transpose_sinking::TSFuse>();
}
manager.register_pass<ov::intel_gna::pass::RemoveInputsProcessing>(input_output_subgraphs);

View File

@@ -8,6 +8,7 @@
#include <transformations/utils/utils.hpp>
#include <utility>
#include "common/graph_utils.hpp"
#include "openvino/opsets/opset12.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
@@ -15,6 +16,7 @@ using namespace ov;
using namespace ov::opset12;
using namespace ov::pass::pattern;
using namespace ov::op::util;
using namespace ov::intel_gna;
using namespace ov::intel_gna::pass;
namespace {
@@ -91,8 +93,7 @@ AxisVector find_suitable_transpose_order(const Shape& input_shape,
const std::vector<AxisVector>& orders) {
for (const auto& order : orders) {
const Shape transposed_shape = apply_permutation(input_shape, order);
if ((transposed_shape.size() == output_shape.size()) &&
std::equal(transposed_shape.begin(), transposed_shape.end(), output_shape.begin()))
if (graph_utils::are_shapes_equal(transposed_shape, output_shape))
return order;
}

View File

@@ -23,6 +23,7 @@ using namespace ov::intel_gna::pass::helper;
using namespace ov::intel_gna::limitations;
namespace {
bool is_concat_sinked(const Output<Node>& output) {
auto concat_node = ov::as_type_ptr<Concat>(output.get_node_shared_ptr());
@@ -33,13 +34,18 @@ bool is_concat_sinked(const Output<Node>& output) {
for (size_t i = 0; i < concat_node->get_input_size(); ++i) {
auto concat_input = concat_node->input_value(i);
auto transpose = ov::as_type_ptr<Transpose>(concat_input.get_node_shared_ptr());
auto target_node = graph_utils::get_prev_node_skipping_certain(concat_input.get_node_shared_ptr(),
graph_utils::is_gna_non_functional_node);
std::shared_ptr<ov::Node> transpose = ov::as_type_ptr<Transpose>(target_node);
if (transpose && !Limitations::is_transpose_supported(transpose))
return true;
}
return false;
}
} // namespace
TSConcatForward::TSConcatForward() {
@@ -57,7 +63,10 @@ TSConcatForward::TSConcatForward() {
OutputVector concat_inputs = {};
for (size_t i = 0; i < concat_node->get_input_size(); ++i) {
ov::Output<ov::Node> concat_input = concat_node->input_value(i);
std::shared_ptr<ov::Node> transpose = ov::as_type_ptr<Transpose>(concat_input.get_node_shared_ptr());
auto target_node = graph_utils::get_prev_node_skipping_certain(concat_input.get_node_shared_ptr(),
graph_utils::is_gna_non_functional_node);
std::shared_ptr<ov::Node> transpose = ov::as_type_ptr<Transpose>(target_node);
ov::Shape transpose_shape = concat_input.get_shape();
ov::AxisVector transpose_order(transpose_shape.size());
@@ -66,7 +75,7 @@ TSConcatForward::TSConcatForward() {
ov::as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
transpose_order = transpose_const->get_axis_vector_val();
transpose_shape = transpose->get_input_shape(0);
concat_input = concat_input.get_node_shared_ptr()->input_value(0);
concat_input = transpose->input_value(0);
} else {
std::iota(transpose_order.begin(), transpose_order.end(), 0);
}
@@ -102,7 +111,13 @@ TSConcatForward::TSConcatForward() {
std::make_shared<Constant>(ov::element::i64, ov::Shape{concat_shape_out.size()}, concat_shape_out);
auto reshape_output = std::make_shared<Reshape>(gather_node, reshape_output_const, false);
ov::replace_node_update_name(concat_node, reshape_output);
// skip reshape if the input and output shapes are the same
if (graph_utils::are_shapes_equal(reshape_output->get_input_shape(0), reshape_output->get_output_shape(0))) {
ov::replace_node_update_name(concat_node, gather_node);
} else {
ov::replace_node_update_name(concat_node, reshape_output);
}
return true;
};

View File

@@ -18,17 +18,21 @@
using namespace ov;
using namespace ov::opset12;
using namespace ov::pass::pattern;
using namespace ov::intel_gna::graph_utils;
using namespace ov::intel_gna::pass;
using namespace ov::intel_gna::pass::helper;
using namespace ov::intel_gna::limitations;
using namespace ov::intel_gna::graph_utils;
namespace {
bool is_split_sinked(const Output<Node>& output) {
auto split_node = output.get_node_shared_ptr();
for (size_t output_idx = 0; output_idx < split_node->get_output_size(); ++output_idx) {
for (auto& input : split_node->get_output_target_inputs(output_idx)) {
auto transpose = ov::as_type_ptr<Transpose>(input.get_node()->shared_from_this());
auto target_node =
get_next_node_skipping_certain(input.get_node()->shared_from_this(), is_gna_non_functional_node);
std::shared_ptr<ov::Node> transpose = ov::as_type_ptr<Transpose>(target_node);
if (transpose && !Limitations::is_transpose_supported(transpose))
return true;
}
@@ -40,19 +44,25 @@ bool is_split_sinked(const Output<Node>& output) {
TSSplitBackward::TSSplitBackward() {
MATCHER_SCOPE(TSSplitBackward);
auto split_node_label = wrap_type<Split>(is_split_sinked);
auto split_node_label = wrap_type<Split, VariadicSplit>(is_split_sinked);
matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto& split_node_label_output = pattern_to_output.at(split_node_label);
auto split_node = as_type_ptr<Split>(split_node_label_output.get_node_shared_ptr());
std::shared_ptr<ov::Node> split_node = as_type_ptr<Split>(split_node_label_output.get_node_shared_ptr());
if (!split_node) {
split_node = as_type_ptr<VariadicSplit>(split_node_label_output.get_node_shared_ptr());
}
ov::AxisVector gather_ids = {};
std::vector<AxisVector> gather_indices_vecs;
std::vector<size_t> split_slices;
for (size_t output_idx = 0; output_idx < split_node->get_output_size(); ++output_idx) {
for (auto& input : split_node->get_output_target_inputs(output_idx)) {
auto transpose = ov::as_type_ptr<Transpose>(input.get_node()->shared_from_this());
auto target_node =
get_next_node_skipping_certain(input.get_node()->shared_from_this(), is_gna_non_functional_node);
std::shared_ptr<ov::Node> transpose = ov::as_type_ptr<Transpose>(target_node);
ov::Shape transpose_shape = split_node->get_output_shape(output_idx);
ov::AxisVector transpose_order(transpose_shape.size());
@@ -74,6 +84,8 @@ TSSplitBackward::TSSplitBackward() {
i += id;
});
gather_ids.insert(gather_ids.end(), slice_ids.begin(), slice_ids.end());
// collect slice sizes
split_slices.push_back(slice_ids.size());
}
}
@@ -89,18 +101,27 @@ TSSplitBackward::TSSplitBackward() {
ov::copy_runtime_info(split_node, {reshape_input, reshape_input_const});
std::shared_ptr<ov::Node> gather;
auto gather_axis = std::make_shared<Constant>(ov::element::i64, ov::Shape{}, 1);
auto gather_indices = std::make_shared<Constant>(ov::element::i64, ov::Shape{gather_ids.size()}, gather_ids);
auto gather = std::make_shared<Gather>(reshape_input, gather_indices, gather_axis);
if (graph_utils::are_shapes_equal(split_input_shape, reshape_input_shape)) {
gather = std::make_shared<Gather>(split_node->input_value(0), gather_indices, gather_axis);
} else {
gather = std::make_shared<Gather>(reshape_input, gather_indices, gather_axis);
}
auto split_axis_new = std::make_shared<Constant>(ov::element::i64, ov::Shape{}, 1);
auto split_new = std::make_shared<Split>(gather, split_axis_new, split_node->get_num_splits());
auto split_new_axis = std::make_shared<Constant>(ov::element::i64, ov::Shape{}, 1);
auto split_new_lengths =
std::make_shared<Constant>(ov::element::i64, ov::Shape{split_slices.size()}, split_slices);
auto split_new = std::make_shared<VariadicSplit>(gather, split_new_axis, split_new_lengths);
ov::copy_runtime_info(split_node, {gather_axis, gather_indices, gather, split_axis_new, split_new});
ov::copy_runtime_info(split_node, {gather_axis, gather_indices, gather, split_new_axis, split_new});
for (size_t output_idx = 0; output_idx < split_node->get_output_size(); ++output_idx) {
for (auto& input : split_node->get_output_target_inputs(output_idx)) {
auto transpose = ov::as_type_ptr<Transpose>(input.get_node()->shared_from_this());
auto target_node = get_next_node_skipping_certain(input.get_node()->shared_from_this(),
graph_utils::is_gna_non_functional_node);
std::shared_ptr<ov::Node> transpose = ov::as_type_ptr<Transpose>(target_node);
if (transpose) {
auto reshape_output_const_new =
std::make_shared<Constant>(ov::element::i64,

View File

@@ -257,7 +257,7 @@ struct GatherInfo {
bool operator==(const GatherInfo& another) {
if (indices.size() != another.indices.size())
return false;
if (!std::equal(indices.begin(), indices.end(), another.indices.begin()))
if (!are_shapes_equal(indices, another.indices))
return false;
return axis == another.axis;
}

View File

@@ -4,6 +4,7 @@
#include "transformation_helper.hpp"
#include "common/graph_utils.hpp"
#include "log/debug.hpp"
#include "openvino/core/rt_info.hpp"
#include "openvino/opsets/opset12.hpp"
@@ -13,6 +14,7 @@
#include "transformations/rt_info/transpose_sinking_attr.hpp"
using namespace ov::opset12;
using namespace ov::intel_gna;
namespace ov {
namespace intel_gna {
@@ -92,7 +94,7 @@ bool TransposeOrderMatches(std::shared_ptr<Transpose> transpose, std::vector<siz
if (data.empty())
return false;
if (order.size() != data.size() || !std::equal(order.begin(), order.end(), data.begin()))
if (!graph_utils::are_shapes_equal(order, data))
return false;
return true;
@@ -156,8 +158,7 @@ void remove_single_input_node(std::shared_ptr<ov::Node> node) {
if (!node_parent) {
THROW_GNA_EXCEPTION << "The removing node has no parrent node";
}
if (input_node_shape.size() != output_node_shape.size() ||
!std::equal(input_node_shape.begin(), input_node_shape.end(), output_node_shape.begin())) {
if (!graph_utils::are_shapes_equal(input_node_shape, output_node_shape)) {
auto reshape_const_node =
std::make_shared<Constant>(ov::element::i64, ov::Shape{output_node_shape.size()}, output_node_shape);
node_parent = std::make_shared<Reshape>(node_parent, reshape_const_node, false);

View File

@@ -43,7 +43,8 @@ TEST(TSSplit, Backward) {
auto reshape1 = std::make_shared<Reshape>(input_params, reshape_const1, false);
auto gather = make_gather(reshape1, TSSplit_Backward_indexes, /* axis */ 1);
auto split_axis = Constant::create(element::i64, ov::Shape{}, ov::Shape{1});
auto split = std::make_shared<Split>(gather, split_axis, 1);
auto split_lengths = Constant::create(element::i64, ov::Shape{1}, ov::Shape{8});
auto split = std::make_shared<VariadicSplit>(gather, split_axis, split_lengths);
auto reshape_const2 = Constant::create(element::i64, ov::Shape{4}, ov::Shape{1, 1, 2, 4});
auto reshape2 = std::make_shared<Reshape>(split->output(0), reshape_const2, false);
const auto result = std::make_shared<Result>(reshape2);