Merge branch 'transpose_sinking_fakequantize' into gna_layout_debug

This commit is contained in:
Evgeny Kotov 2023-03-21 19:48:06 +01:00
commit 02abf9b1f0

View File

@ -8,6 +8,7 @@
#include "itt.hpp"
#include "openvino/opsets/opset10.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/rt_info/transpose_sinking_attr.hpp"
#include "transformations/transpose_sinking/ts_utils.hpp"
@ -19,6 +20,7 @@ using namespace ov::pass::pattern;
using namespace ov::op::util;
using namespace ov::pass::transpose_sinking;
using namespace ov::pass::transpose_sinking::utils;
using namespace ov::pass::pattern::op;
namespace {
@ -40,41 +42,45 @@ NodePair SwapNodes(const NodePtr& first_node, const NodePtr& second_node) {
auto first_node_inputs = first_node->input_values();
first_node_inputs[0] = new_first_node;
auto new_second_node = first_node->clone_with_new_inputs(first_node_inputs);
new_second_node->set_friendly_name(second_node->get_friendly_name());
ov::copy_runtime_info({first_node, second_node}, {new_first_node, new_second_node});
ov::copy_runtime_info({first_node, second_node}, {new_first_node, new_second_node});
ov::replace_node(second_node, new_second_node);
return std::make_pair(new_first_node, new_second_node);
}
NodePtr GetPatternNode(const PatternValueMap& pattern_to_output, const NodeVector& nodes) {
for (const auto& node : nodes) {
auto it = pattern_to_output.find(node);
if (it == pattern_to_output.end())
continue;
return it->second.get_node_shared_ptr();
}
return {};
}
} // namespace
TSUnaryForward::TSUnaryForward() {
MATCHER_SCOPE(TSUnaryForward);
auto transpose_label = wrap_type<Transpose>({any_input(), any_input()});
auto fq_label = wrap_type<FakeQuantize>({transpose_label, any_input(), any_input(), any_input(), any_input()});
auto unary_op_label =
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>({transpose_label});
auto unary_label = std::make_shared<pattern::op::Or>(OutputVector{fq_label, unary_op_label});
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>(
{transpose_label});
auto unary_label = std::make_shared<Or>(OutputVector{fq_label, unary_op_label});
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto unary = GetPatternNode(pattern_to_output, NodeVector{unary_op_label, fq_label});
const NodePair new_nodes = SwapNodes(transpose, unary);
register_new_node(new_nodes.first);
register_new_node(new_nodes.second);
UpdateForwardSinkingAbility(new_nodes.second);
return true;
};
auto m = std::make_shared<Matcher>(unary_label, matcher_name);
register_matcher(m, matcher_pass_callback);
}
@ -87,28 +93,23 @@ bool IfSinkingEnabled(const Output<Node>& output) {
TSUnaryBackward::TSUnaryBackward() {
MATCHER_SCOPE(TSUnaryBackwardMultiConsumers);
auto unary_restrictions = [](const Output<Node>& output) -> bool {
return HasSameOutputTransposeNodes(output);
};
auto fq_label =
wrap_type<FakeQuantize>({any_input(), any_input(), any_input(), any_input(), any_input()}, unary_restrictions);
auto unary_op_label =
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>({any_input()},
unary_restrictions);
auto unary_label = std::make_shared<pattern::op::Or>(OutputVector{fq_label, unary_op_label});
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>(
{any_input()},
unary_restrictions);
auto unary_label = std::make_shared<Or>(OutputVector{fq_label, unary_op_label});
auto transpose_const_label = wrap_type<Constant>();
auto transpose_label = wrap_type<Transpose>({unary_label, transpose_const_label}, IfSinkingEnabled);
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
const auto& pattern_to_output = m.get_pattern_value_map();
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
auto unary = GetPatternNode(pattern_to_output, NodeVector{unary_op_label, fq_label});
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary,
transpose_const,
/* input_indexes */ {0})) {