[GNA] Fixed Transpose replacement with Gather (#18624)
This commit is contained in:
@@ -19,154 +19,120 @@
|
||||
using namespace ov::intel_gna;
|
||||
using namespace ov::intel_gna::pass;
|
||||
using namespace ov::intel_gna::limitations;
|
||||
using namespace ov::opset9;
|
||||
using namespace ov::opset12;
|
||||
using namespace ov::pass::pattern;
|
||||
using namespace gather_sinking;
|
||||
|
||||
namespace {
|
||||
|
||||
using NodePtr = std::shared_ptr<ov::Node>;
|
||||
using NodePair = std::pair<NodePtr, NodePtr>;
|
||||
std::vector<std::shared_ptr<ov::Node>> gather_sink_forward(std::shared_ptr<ov::Node> transpose,
|
||||
std::shared_ptr<ov::Node> reshape) {
|
||||
std::vector<std::shared_ptr<ov::Node>> new_nodes = {};
|
||||
|
||||
std::vector<size_t> CreateGatherIndices(const ov::Shape& input_shape, const ov::Shape& order) {
|
||||
if (input_shape.size() < 2 || input_shape.size() > 4) {
|
||||
THROW_GNA_EXCEPTION << "Usupported shape size: " << input_shape.size();
|
||||
auto transpose_const = ov::as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
const auto gather_indexes_value =
|
||||
graph_utils::make_gather_indexes_from_transpose_axes(transpose->get_input_shape(0),
|
||||
transpose_const->get_axis_vector_val());
|
||||
ov::Shape shape_out = reshape->get_output_shape(0);
|
||||
|
||||
// Set Gather input shape
|
||||
std::vector<int8_t> reshape_in_dims = {1, -1};
|
||||
for (size_t i = reshape_in_dims.size(); i < shape_out.size(); ++i) {
|
||||
if (shape_out[i] == 1) {
|
||||
reshape_in_dims.push_back(1);
|
||||
}
|
||||
}
|
||||
// reshape
|
||||
auto reshape_in_const =
|
||||
std::make_shared<Constant>(ov::element::i64, ov::Shape{reshape_in_dims.size()}, reshape_in_dims);
|
||||
auto reshape_in = reshape->clone_with_new_inputs({transpose->input_value(0), reshape_in_const});
|
||||
new_nodes.push_back(reshape_in);
|
||||
|
||||
const int64_t gather_axis_value = graph_utils::get_first_valuable_dim_id(reshape_in->get_output_shape(0));
|
||||
auto gather_axis = std::make_shared<Constant>(ov::element::i64, ov::Shape{}, gather_axis_value);
|
||||
auto gather_indices =
|
||||
std::make_shared<Constant>(ov::element::i64, ov::Shape{gather_indexes_value.size()}, gather_indexes_value);
|
||||
auto gather = std::make_shared<Gather>(reshape_in, gather_indices, gather_axis);
|
||||
new_nodes.push_back(gather);
|
||||
|
||||
auto reshape_out_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_out.size()}, shape_out);
|
||||
auto reshape_out = std::make_shared<Reshape>(gather, reshape_out_const, false);
|
||||
|
||||
if (!graph_utils::are_shapes_equal(reshape_out->get_input_shape(0), reshape_out->get_output_shape(0))) {
|
||||
new_nodes.push_back(reshape_out);
|
||||
}
|
||||
|
||||
ov::Shape input_shape_4d = input_shape;
|
||||
ov::Shape order_4d = order;
|
||||
// Just to simplify the code we transform all shapes to 4d by adding 1 dimentions at the end
|
||||
while (input_shape_4d.size() < 4) {
|
||||
input_shape_4d.push_back(1);
|
||||
order_4d.push_back(order_4d.size());
|
||||
}
|
||||
ov::Shape output_shape_4d = graph_utils::transpose_shape(input_shape_4d, order_4d);
|
||||
ov::replace_output_update_name(reshape->output(0), new_nodes.back()->output(0));
|
||||
ov::copy_runtime_info({reshape}, new_nodes);
|
||||
|
||||
// common case when shape is 4d
|
||||
std::vector<size_t> xyz_4d = {input_shape_4d[3] * input_shape_4d[2] * input_shape_4d[1],
|
||||
input_shape_4d[3] * input_shape_4d[2],
|
||||
input_shape_4d[3],
|
||||
1};
|
||||
return new_nodes;
|
||||
}
|
||||
|
||||
std::vector<size_t> xyz = graph_utils::transpose_shape(xyz_4d, order_4d);
|
||||
std::vector<size_t> gather_order;
|
||||
std::vector<std::shared_ptr<ov::Node>> gather_sink_backward(std::shared_ptr<ov::Node> reshape,
|
||||
std::shared_ptr<ov::Node> transpose) {
|
||||
std::vector<std::shared_ptr<ov::Node>> new_nodes = {};
|
||||
auto transpose_const = ov::as_type_ptr<Constant>(transpose->get_input_node_shared_ptr(1));
|
||||
ov::Shape shape_out = transpose->get_output_shape(0);
|
||||
|
||||
for (size_t n = 0; n < output_shape_4d[0]; ++n) {
|
||||
for (size_t i = 0; i < output_shape_4d[1]; ++i) {
|
||||
for (size_t j = 0; j < output_shape_4d[2]; ++j) {
|
||||
for (size_t k = 0; k < output_shape_4d[3]; ++k) {
|
||||
gather_order.push_back(n * xyz[0] + i * xyz[1] + j * xyz[2] + k * xyz[3]);
|
||||
}
|
||||
}
|
||||
// Set Gather input shape
|
||||
std::vector<int8_t> reshape_in_dims = {1, -1};
|
||||
for (size_t i = reshape_in_dims.size(); i < shape_out.size(); ++i) {
|
||||
if (shape_out[i] == 1) {
|
||||
reshape_in_dims.push_back(1);
|
||||
}
|
||||
}
|
||||
|
||||
return gather_order;
|
||||
}
|
||||
// reshape
|
||||
auto reshape_in_const =
|
||||
std::make_shared<Constant>(ov::element::i64, ov::Shape{reshape_in_dims.size()}, reshape_in_dims);
|
||||
auto reshape_in = reshape->clone_with_new_inputs({reshape->input_value(0), reshape_in_const});
|
||||
new_nodes.push_back(reshape_in);
|
||||
|
||||
NodePair SinkForward(NodePtr transpose, std::shared_ptr<Constant> transpose_constant, NodePtr reshape) {
|
||||
const auto gather_indices_value =
|
||||
CreateGatherIndices(transpose->get_input_shape(0), transpose_constant->get_axis_vector_val());
|
||||
|
||||
const int64_t gather_axis_value = graph_utils::get_first_valuable_dim_id(reshape->get_output_shape(0));
|
||||
|
||||
auto reshape_new = reshape->clone_with_new_inputs({transpose->input_value(0), reshape->input_value(1)});
|
||||
const int64_t gather_axis_value = graph_utils::get_first_valuable_dim_id(reshape_in->get_output_shape(0));
|
||||
const auto gather_indexes_value =
|
||||
graph_utils::make_gather_indexes_from_transpose_axes(transpose->get_input_shape(0),
|
||||
transpose_const->get_axis_vector_val());
|
||||
|
||||
auto gather_axis = std::make_shared<Constant>(ov::element::i64, ov::Shape{}, gather_axis_value);
|
||||
auto gather_indices =
|
||||
std::make_shared<Constant>(ov::element::i64, ov::Shape{gather_indices_value.size()}, gather_indices_value);
|
||||
auto gather = std::make_shared<Gather>(reshape_new, gather_indices, gather_axis);
|
||||
std::make_shared<Constant>(ov::element::i64, ov::Shape{gather_indexes_value.size()}, gather_indexes_value);
|
||||
auto gather = std::make_shared<Gather>(reshape_in, gather_indices, gather_axis);
|
||||
new_nodes.push_back(gather);
|
||||
|
||||
ov::replace_node(reshape, gather);
|
||||
auto reshape_out_const = std::make_shared<Constant>(ov::element::i64, ov::Shape{shape_out.size()}, shape_out);
|
||||
auto reshape_out = std::make_shared<Reshape>(gather, reshape_out_const, false);
|
||||
|
||||
ov::copy_runtime_info({reshape}, {gather, gather_indices, gather_axis, reshape_new});
|
||||
gather->set_friendly_name(reshape->get_friendly_name());
|
||||
|
||||
return std::make_pair(reshape_new, gather);
|
||||
}
|
||||
|
||||
NodePair SinkBackward(NodePtr transpose, std::shared_ptr<Constant> transpose_constant, NodePtr reshape) {
|
||||
const int64_t gather_axis_value = graph_utils::get_first_valuable_dim_id(reshape->get_input_shape(0));
|
||||
const auto gather_indices_value =
|
||||
CreateGatherIndices(transpose->get_input_shape(0), transpose_constant->get_axis_vector_val());
|
||||
|
||||
auto gather_axis = std::make_shared<Constant>(ov::element::i64, ov::Shape{}, gather_axis_value);
|
||||
auto gather_indices =
|
||||
std::make_shared<Constant>(ov::element::i64, ov::Shape{gather_indices_value.size()}, gather_indices_value);
|
||||
auto gather = std::make_shared<Gather>(reshape->input_value(0), gather_indices, gather_axis);
|
||||
|
||||
auto reshape_const_new = std::make_shared<Constant>(ov::element::i64,
|
||||
ov::Shape{transpose->get_output_shape(0).size()},
|
||||
transpose->get_output_shape(0));
|
||||
auto reshape_new = std::make_shared<Reshape>(gather, reshape_const_new, false);
|
||||
|
||||
ov::replace_node(transpose, reshape_new);
|
||||
|
||||
ov::copy_runtime_info({transpose}, {gather, gather_indices, gather_axis, reshape_new, reshape_const_new});
|
||||
reshape_new->set_friendly_name(transpose->get_friendly_name());
|
||||
|
||||
return std::make_pair(gather, reshape_new);
|
||||
}
|
||||
|
||||
bool AreFlattenShapes(const ov::Shape& shape1, const ov::Shape& shape2) {
|
||||
size_t i = 0;
|
||||
// find non-equal parts
|
||||
while (shape1[i] == shape2[i]) {
|
||||
++i;
|
||||
if (!graph_utils::are_shapes_equal(reshape_out->get_input_shape(0), reshape_out->get_output_shape(0))) {
|
||||
new_nodes.push_back(reshape_out);
|
||||
}
|
||||
// consider only last dimension to be flatten/unflatten
|
||||
if (shape1.size() - 1 != i && shape2.size() - 1 != i)
|
||||
return false;
|
||||
// min_shape.back() == MULTIPLY(max_shape.begin() + i, max_shape.end())
|
||||
const size_t mult1 = std::accumulate(shape1.begin() + i, shape1.end(), std::size_t{1}, std::multiplies<size_t>());
|
||||
const size_t mult2 = std::accumulate(shape2.begin() + i, shape2.end(), std::size_t{1}, std::multiplies<size_t>());
|
||||
return mult1 == mult2;
|
||||
}
|
||||
|
||||
bool IsTailFlatten(const ov::Output<ov::Node>& output) {
|
||||
std::shared_ptr<ov::Node> reshape_node = output.get_node_shared_ptr();
|
||||
if (reshape_node->get_output_partial_shape(0).rank().is_dynamic() ||
|
||||
reshape_node->get_input_partial_shape(0).rank().is_dynamic())
|
||||
return false;
|
||||
const ov::Shape input_shape = graph_utils::trim_shape(reshape_node->get_input_shape(0));
|
||||
const ov::Shape output_shape = graph_utils::trim_shape(reshape_node->get_output_shape(0));
|
||||
return input_shape.size() > output_shape.size() && AreFlattenShapes(input_shape, output_shape);
|
||||
}
|
||||
ov::replace_node_update_name(transpose, new_nodes.back());
|
||||
ov::copy_runtime_info({transpose}, new_nodes);
|
||||
|
||||
bool IsTailUnflatten(const ov::Output<ov::Node>& output) {
|
||||
std::shared_ptr<ov::Node> reshape_node = output.get_node_shared_ptr();
|
||||
if (reshape_node->get_output_partial_shape(0).rank().is_dynamic() ||
|
||||
reshape_node->get_input_partial_shape(0).rank().is_dynamic())
|
||||
return false;
|
||||
const ov::Shape input_shape = graph_utils::trim_shape(reshape_node->get_input_shape(0));
|
||||
const ov::Shape output_shape = graph_utils::trim_shape(reshape_node->get_output_shape(0));
|
||||
return input_shape.size() < output_shape.size() && AreFlattenShapes(input_shape, output_shape);
|
||||
return new_nodes;
|
||||
}
|
||||
|
||||
bool is_transpose_unsupported(const ov::Output<ov::Node>& output) {
|
||||
return !Limitations::is_transpose_supported(output.get_node_shared_ptr());
|
||||
}
|
||||
|
||||
bool IfBackwardSinkingEnabled(const ov::Output<ov::Node>& output) {
|
||||
bool is_backward_sinking_enabled(const ov::Output<ov::Node>& output) {
|
||||
return is_transpose_unsupported(output) && ov::is_sinking_node(output.get_node_shared_ptr());
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// working with situation when we transpose dims that are flatten/unflatten
|
||||
// consider only if flatten/unflatten are last dimensions
|
||||
GatherSinkingTransposeReshapeForward::GatherSinkingTransposeReshapeForward() {
|
||||
MATCHER_SCOPE(GatherSinkingTransposeReshapeForward);
|
||||
|
||||
auto transpose_const_label = wrap_type<Constant>();
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), transpose_const_label}, is_transpose_unsupported);
|
||||
auto reshape_label = wrap_type<Reshape>({transpose_label, any_input()}, IsTailFlatten);
|
||||
auto reshape_label = wrap_type<Reshape>({transpose_label, any_input()});
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto transpose_const =
|
||||
ov::as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||
auto reshape = pattern_to_output.at(reshape_label).get_node_shared_ptr();
|
||||
|
||||
const ov::Shape reshape_shape = graph_utils::trim_shape(reshape->get_shape());
|
||||
@@ -176,12 +142,12 @@ GatherSinkingTransposeReshapeForward::GatherSinkingTransposeReshapeForward() {
|
||||
return true;
|
||||
}
|
||||
|
||||
const NodePair new_nodes = SinkForward(transpose, transpose_const, reshape);
|
||||
const std::vector<std::shared_ptr<ov::Node>> new_nodes = gather_sink_forward(transpose, reshape);
|
||||
|
||||
register_new_node(new_nodes.first);
|
||||
register_new_node(new_nodes.second);
|
||||
for (const auto& node : new_nodes) {
|
||||
register_new_node(node);
|
||||
}
|
||||
|
||||
update_forward_gather_sinking_ability(new_nodes.second);
|
||||
return true;
|
||||
};
|
||||
|
||||
@@ -192,15 +158,15 @@ GatherSinkingTransposeReshapeForward::GatherSinkingTransposeReshapeForward() {
|
||||
GatherSinkingTransposeReshapeBackward::GatherSinkingTransposeReshapeBackward() {
|
||||
MATCHER_SCOPE(GatherSinkingTransposeReshapeBackward);
|
||||
|
||||
auto reshape_label = wrap_type<Reshape>({any_input(), any_input()}, IsTailUnflatten);
|
||||
auto reshape_label = wrap_type<Reshape>({any_input(), any_input()});
|
||||
auto transpose_const_label = wrap_type<Constant>();
|
||||
auto transpose_label = wrap_type<Transpose>({reshape_label, transpose_const_label}, IfBackwardSinkingEnabled);
|
||||
auto transpose_label = wrap_type<Transpose>({reshape_label, transpose_const_label}, is_backward_sinking_enabled);
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||
|
||||
auto reshape = pattern_to_output.at(reshape_label).get_node_shared_ptr();
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
|
||||
const ov::Shape reshape_shape = graph_utils::trim_shape(reshape->get_input_shape(0));
|
||||
const ov::Shape transpose_shape = graph_utils::trim_shape(transpose->get_shape());
|
||||
@@ -209,9 +175,10 @@ GatherSinkingTransposeReshapeBackward::GatherSinkingTransposeReshapeBackward() {
|
||||
return true;
|
||||
}
|
||||
|
||||
const NodePair new_nodes = SinkBackward(transpose, transpose_const, reshape);
|
||||
register_new_node(new_nodes.first);
|
||||
register_new_node(new_nodes.second);
|
||||
const std::vector<std::shared_ptr<ov::Node>> new_nodes = gather_sink_backward(reshape, transpose);
|
||||
for (const auto& node : new_nodes) {
|
||||
register_new_node(node);
|
||||
}
|
||||
|
||||
return true;
|
||||
};
|
||||
|
||||
@@ -163,7 +163,8 @@ HandleTransposeBeforeMatMul::HandleTransposeBeforeMatMul() {
|
||||
}
|
||||
|
||||
if (prev_node) {
|
||||
if (Limitations::is_transpose_supported(prev_node->get_output_shape(0))) {
|
||||
if (graph_utils::is_shape_2d(prev_node->get_output_shape(0)) &&
|
||||
Limitations::is_transpose_supported(prev_node->get_output_shape(0))) {
|
||||
InsertTranspose(prev_node, matmul_node->get_friendly_name(), true);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ TEST(GatherSinkingTransposeReshape, ForwardSinking3D) {
|
||||
auto input_params = std::make_shared<Parameter>(element::Type_t::f32, Shape{1, 1, 14, 4});
|
||||
auto tanh0 = std::make_shared<Tanh>(input_params);
|
||||
|
||||
auto reshape_const = std::make_shared<Constant>(element::i64, Shape{2}, std::vector<int>{1, 56});
|
||||
auto reshape_const = std::make_shared<Constant>(element::i64, Shape{2}, std::vector<int>{1, -1});
|
||||
auto reshape = std::make_shared<Reshape>(tanh0, reshape_const, false);
|
||||
|
||||
auto generate_indices = []() -> std::vector<int64_t> {
|
||||
@@ -121,7 +121,7 @@ TEST(GatherSinkingTransposeReshape, ForwardSinking3D) {
|
||||
const FunctionsComparator func_comparator =
|
||||
FunctionsComparator::with_default().enable(FunctionsComparator::ATTRIBUTES);
|
||||
const FunctionsComparator::Result result = func_comparator(function, reference_function);
|
||||
ASSERT_TRUE(result.valid);
|
||||
ASSERT_TRUE(result.valid) << result.message;
|
||||
}
|
||||
|
||||
TEST(GatherSinkingTransposeReshape, BackwardSinking) {
|
||||
|
||||
Reference in New Issue
Block a user