Merge branch 'transpose_sinking_fakequantize' into gna_layout_debug
This commit is contained in:
commit
02abf9b1f0
@ -8,6 +8,7 @@
|
||||
|
||||
#include "itt.hpp"
|
||||
#include "openvino/opsets/opset10.hpp"
|
||||
#include "openvino/pass/pattern/op/or.hpp"
|
||||
#include "openvino/pass/pattern/op/wrap_type.hpp"
|
||||
#include "transformations/rt_info/transpose_sinking_attr.hpp"
|
||||
#include "transformations/transpose_sinking/ts_utils.hpp"
|
||||
@ -19,6 +20,7 @@ using namespace ov::pass::pattern;
|
||||
using namespace ov::op::util;
|
||||
using namespace ov::pass::transpose_sinking;
|
||||
using namespace ov::pass::transpose_sinking::utils;
|
||||
using namespace ov::pass::pattern::op;
|
||||
|
||||
namespace {
|
||||
|
||||
@ -40,41 +42,45 @@ NodePair SwapNodes(const NodePtr& first_node, const NodePtr& second_node) {
|
||||
|
||||
auto first_node_inputs = first_node->input_values();
|
||||
first_node_inputs[0] = new_first_node;
|
||||
|
||||
auto new_second_node = first_node->clone_with_new_inputs(first_node_inputs);
|
||||
|
||||
new_second_node->set_friendly_name(second_node->get_friendly_name());
|
||||
ov::copy_runtime_info({first_node, second_node}, {new_first_node, new_second_node});
|
||||
|
||||
ov::copy_runtime_info({first_node, second_node}, {new_first_node, new_second_node});
|
||||
ov::replace_node(second_node, new_second_node);
|
||||
|
||||
return std::make_pair(new_first_node, new_second_node);
|
||||
}
|
||||
|
||||
NodePtr GetPatternNode(const PatternValueMap& pattern_to_output, const NodeVector& nodes) {
|
||||
for (const auto& node : nodes) {
|
||||
auto it = pattern_to_output.find(node);
|
||||
if (it == pattern_to_output.end())
|
||||
continue;
|
||||
return it->second.get_node_shared_ptr();
|
||||
}
|
||||
return {};
|
||||
}
|
||||
} // namespace
|
||||
|
||||
TSUnaryForward::TSUnaryForward() {
|
||||
MATCHER_SCOPE(TSUnaryForward);
|
||||
|
||||
auto transpose_label = wrap_type<Transpose>({any_input(), any_input()});
|
||||
auto fq_label = wrap_type<FakeQuantize>({transpose_label, any_input(), any_input(), any_input(), any_input()});
|
||||
auto unary_op_label =
|
||||
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>({transpose_label});
|
||||
auto unary_label = std::make_shared<pattern::op::Or>(OutputVector{fq_label, unary_op_label});
|
||||
|
||||
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>(
|
||||
{transpose_label});
|
||||
auto unary_label = std::make_shared<Or>(OutputVector{fq_label, unary_op_label});
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto unary = GetPatternNode(pattern_to_output, NodeVector{unary_op_label, fq_label});
|
||||
|
||||
const NodePair new_nodes = SwapNodes(transpose, unary);
|
||||
|
||||
register_new_node(new_nodes.first);
|
||||
register_new_node(new_nodes.second);
|
||||
|
||||
UpdateForwardSinkingAbility(new_nodes.second);
|
||||
return true;
|
||||
};
|
||||
|
||||
auto m = std::make_shared<Matcher>(unary_label, matcher_name);
|
||||
register_matcher(m, matcher_pass_callback);
|
||||
}
|
||||
@ -87,28 +93,23 @@ bool IfSinkingEnabled(const Output<Node>& output) {
|
||||
|
||||
TSUnaryBackward::TSUnaryBackward() {
|
||||
MATCHER_SCOPE(TSUnaryBackwardMultiConsumers);
|
||||
|
||||
auto unary_restrictions = [](const Output<Node>& output) -> bool {
|
||||
return HasSameOutputTransposeNodes(output);
|
||||
};
|
||||
|
||||
auto fq_label =
|
||||
wrap_type<FakeQuantize>({any_input(), any_input(), any_input(), any_input(), any_input()}, unary_restrictions);
|
||||
auto unary_op_label =
|
||||
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>({any_input()},
|
||||
unary_restrictions);
|
||||
auto unary_label = std::make_shared<pattern::op::Or>(OutputVector{fq_label, unary_op_label});
|
||||
|
||||
wrap_type<UnaryElementwiseArithmetic, Clamp, Elu, SoftPlus, LogicalNot, Convert, IsInf, IsNaN, IsFinite>(
|
||||
{any_input()},
|
||||
unary_restrictions);
|
||||
auto unary_label = std::make_shared<Or>(OutputVector{fq_label, unary_op_label});
|
||||
auto transpose_const_label = wrap_type<Constant>();
|
||||
|
||||
auto transpose_label = wrap_type<Transpose>({unary_label, transpose_const_label}, IfSinkingEnabled);
|
||||
|
||||
ov::matcher_pass_callback matcher_pass_callback = [=](Matcher& m) {
|
||||
const auto& pattern_to_output = m.get_pattern_value_map();
|
||||
auto transpose_const = as_type_ptr<Constant>(pattern_to_output.at(transpose_const_label).get_node_shared_ptr());
|
||||
auto transpose = pattern_to_output.at(transpose_label).get_node_shared_ptr();
|
||||
auto unary = GetPatternNode(pattern_to_output, NodeVector{unary_op_label, fq_label});
|
||||
|
||||
for (auto& new_node : sink_backward::InsertTransposeBeforeNode(unary,
|
||||
transpose_const,
|
||||
/* input_indexes */ {0})) {
|
||||
|
Loading…
Reference in New Issue
Block a user